def explain_graphs(self, graph_indices):
    """Explain whole-graph predictions and log each denoised explanation.

    For every index in ``graph_indices`` this calls ``self.explain`` in graph
    mode (with ``node_idx=0``), denoises the returned masked adjacency with
    ``io_utils.denoise_graph`` (``threshold_num=20``) and logs the resulting
    graph through ``self.writer``. After the loop the ``tab20`` colour map is
    logged once.

    Args:
        graph_indices: Iterable of graph indices to explain. Each index is
            used to look up ``self.feat`` and ``self.label``.

    Returns:
        list: The masked adjacency returned by ``self.explain`` for each
        entry of ``graph_indices``, in the same order.
    """
    masked_adjs = []
    for graph_idx in graph_indices:
        masked_adj = self.explain(node_idx=0, graph_idx=graph_idx, graph_mode=True)
        G_denoised = io_utils.denoise_graph(
            masked_adj,
            0,
            threshold_num=20,
            feat=self.feat[graph_idx],
            max_component=False,
        )
        label = self.label[graph_idx]
        io_utils.log_graph(
            self.writer,
            G_denoised,
            "graph/graphidx_{}_label={}".format(graph_idx, label),
            identify_self=False,
            nodecolor="feat",
        )
        masked_adjs.append(masked_adj)

    # plot cmap for graphs' node features
    io_utils.plot_cmap_tb(self.writer, "tab20", 20, "tab20_cmap")
    return masked_adjs
def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp"):
    """Explain several nodes and write ROC-AUC / precision-recall statistics.

    Each node is explained with ``self.explain``; the denoised explanation
    graph is logged through ``self.writer``. The per-node ``(pred, real)``
    pairs returned by ``self.make_pred_real`` are concatenated and used to
    compute a ROC-AUC score and a precision-recall curve, which are written
    to ``log/pr/pr_<dataset>_<model>.png`` and
    ``log/pr/auc_<dataset>_<model>.txt``.

    Args:
        node_indices: Sequence of node indices to explain. Must be non-empty,
            because the predictions are concatenated with ``np.concatenate``.
        args: Not read by this method; ``self.args`` is used instead.
        graph_idx: Graph index forwarded to ``self.explain`` and
            ``self.extract_neighborhood``.
        model: Explanation method name, forwarded to ``self.explain`` and
            used in log tags and output file names.

    Returns:
        list: One masked adjacency per entry of ``node_indices``.
    """
    masked_adjs = [
        self.explain(node_idx, graph_idx=graph_idx, model=model)
        for node_idx in node_indices
    ]

    pred_all = []
    real_all = []
    for i, idx in enumerate(node_indices):
        # Use the same graph the explanation was computed on.
        new_idx, _, feat, _, _ = self.extract_neighborhood(idx, graph_idx)
        G = io_utils.denoise_graph(masked_adjs[i], new_idx, feat, threshold_num=20)
        pred, real = self.make_pred_real(masked_adjs[i], new_idx)
        pred_all.append(pred)
        real_all.append(real)

        io_utils.log_graph(
            self.writer,
            G,
            "graph/{}_{}_{}".format(self.args.dataset, model, i),
            identify_self=True,
        )

    pred_all = np.concatenate(pred_all, axis=0)
    real_all = np.concatenate(real_all, axis=0)

    auc_all = roc_auc_score(real_all, pred_all)
    precision, recall, _ = precision_recall_curve(real_all, pred_all)

    # Non-interactive backend so the figure can be saved without a display.
    plt.switch_backend("agg")
    plt.plot(recall, precision)
    plt.savefig("log/pr/pr_" + self.args.dataset + "_" + model + ".png")
    plt.close()

    with open("log/pr/auc_" + self.args.dataset + "_" + model + ".txt", "w") as f:
        # Record the method that was actually evaluated, matching the file name.
        f.write(
            "dataset: {}, model: {}, auc: {}\n".format(
                self.args.dataset, model, str(auc_all)
            )
        )

    return masked_adjs
def log_masked_adj(self, node_idx, epoch, name="mask/graph", label=None):
    """Denoise the current masked adjacency and log it as a graph.

    In graph mode the graph is thresholded at 0.2 and coloured by node
    feature; otherwise the top 12 edges are kept (``threshold_num=12``) and
    nodes are coloured by label with the query node highlighted.

    Args:
        node_idx: Node index forwarded to ``io_utils.denoise_graph``.
        epoch: Epoch value forwarded to ``io_utils.log_graph``.
        name: Tag under which the graph is logged.
        label: Accepted but not used by this method.
    """
    # Index 0 strips the batch dimension.
    mask_np = self.masked_adj[0].cpu().detach().numpy()

    if self.graph_mode:
        graph = io_utils.denoise_graph(
            mask_np,
            node_idx,
            feat=self.x[0],
            threshold=0.2,
            max_component=True,
        )
        mode_opts = {
            "identify_self": False,
            "nodecolor": "feat",
            "label_node_feat": True,
        }
    else:
        graph = io_utils.denoise_graph(
            mask_np, node_idx, threshold_num=12, max_component=True
        )
        mode_opts = {"identify_self": True, "nodecolor": "label"}

    io_utils.log_graph(
        self.writer,
        graph,
        name=name,
        epoch=epoch,
        edge_vmax=None,
        args=self.args,
        **mode_opts
    )
def log_adj_grad(self, node_idx, pred_label, epoch, label=None):
    """Log gradient-based views of the adjacency and the node features.

    Calls ``self.adj_feat_grad`` (presumably the gradients of the prediction
    with respect to the adjacency and the features; confirm against its
    definition), symmetrises the absolute adjacency gradient, masks it with
    ``self.adj`` and logs the result as a graph under ``grad/graph``. The
    feature gradient is logged as a matrix under ``grad/feat``. In graph mode
    the unthresholded masked adjacency is additionally logged under
    ``grad/graph_orig``.

    Args:
        node_idx: Index of the query node.
        pred_label: Predicted label in graph mode; otherwise a per-node
            collection indexed with ``node_idx``.
        epoch: Epoch value forwarded to the logging helpers.
        label: Accepted but not used by this method.
    """
    # Debug switch: set to True to also log the raw adjacency matrices.
    log_adj = False

    if self.graph_mode:
        predicted_label = pred_label
        adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label)
        adj_grad = torch.abs(adj_grad)[0]
        x_grad = torch.sum(x_grad[0], 0, keepdim=True).t()
    else:
        predicted_label = pred_label[node_idx]
        adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label)
        adj_grad = torch.abs(adj_grad)[self.graph_idx]
        x_grad = x_grad[self.graph_idx][node_idx][:, np.newaxis]

    # Symmetrise, then keep only entries that are edges in the input graph.
    adj_grad = (adj_grad + adj_grad.t()) / 2
    adj_grad = (adj_grad * self.adj).squeeze()
    if log_adj:
        io_utils.log_matrix(self.writer, adj_grad, "grad/adj_masked", epoch)
        self.adj.requires_grad = False
        io_utils.log_matrix(self.writer, self.adj.squeeze(), "grad/adj_orig", epoch)

    masked_adj = self.masked_adj[0].cpu().detach().numpy()

    # only for graph mode since many node neighborhoods for syn tasks are
    # relatively large for visualization
    if self.graph_mode:
        G = io_utils.denoise_graph(
            masked_adj, node_idx, feat=self.x[0], threshold=None, max_component=False
        )
        io_utils.log_graph(
            self.writer,
            G,
            name="grad/graph_orig",
            epoch=epoch,
            identify_self=False,
            label_node_feat=True,
            nodecolor="feat",
            edge_vmax=None,
            args=self.args,
        )
    io_utils.log_matrix(self.writer, x_grad, "grad/feat", epoch)

    # Move to host memory first: Tensor.numpy() only works on CPU tensors.
    adj_grad = adj_grad.detach().cpu().numpy()
    if self.graph_mode:
        print("GRAPH model")
        G = io_utils.denoise_graph(
            adj_grad,
            node_idx,
            feat=self.x[0],
            threshold=0.0003,
            max_component=True,
        )
        io_utils.log_graph(
            self.writer,
            G,
            name="grad/graph",
            epoch=epoch,
            identify_self=False,
            label_node_feat=True,
            nodecolor="feat",
            edge_vmax=None,
            args=self.args,
        )
    else:
        G = io_utils.denoise_graph(adj_grad, node_idx, threshold_num=12)
        io_utils.log_graph(
            self.writer, G, name="grad/graph", epoch=epoch, args=self.args
        )
def explain(
    self, node_idx, graph_idx=0, graph_mode=False, unconstrained=False, model="exp"
):
    """Explain a single node prediction (or a whole graph in graph mode).

    Builds an ``ExplainModule`` over the relevant (sub)graph and derives a
    masked adjacency with one of three methods selected by ``model``:

    * ``"grad"``: sigmoid of the symmetrised absolute adjacency gradient.
    * ``"exp"``: the mask learned by optimising the explainer for
      ``self.args.num_epochs`` epochs.
    * anything else: sigmoid of the attention values returned by a single
      forward pass of the explainer (the loop stops after the first epoch).

    The result is saved as ``.npy`` under ``self.args.logdir`` and returned.

    Args:
        node_idx: Index of the node to explain (in the full graph).
        graph_idx: Index of the graph containing the node.
        graph_mode: If true, explain the whole graph ``graph_idx`` instead
            of extracting a neighbourhood around ``node_idx``.
        unconstrained: Forwarded to the explainer's forward pass.
        model: Explanation method, see above.

    Returns:
        numpy.ndarray: Masked adjacency, multiplied element-wise by the
        (sub)graph adjacency.
    """
    # NOTE(review): the (sub)graph is selected by the ``graph_mode`` argument
    # but the predicted label below is selected by ``self.graph_mode``;
    # callers are presumably expected to keep the two in sync - confirm.

    # index of the query node in the new adj
    if graph_mode:
        node_idx_new = node_idx
        sub_adj = self.adj[graph_idx]
        sub_feat = self.feat[graph_idx, :]
        sub_label = self.label[graph_idx]
        neighbors = np.asarray(range(self.adj.shape[0]))
    else:
        print("node label: ", self.label[graph_idx][node_idx])
        node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood(
            node_idx, graph_idx
        )
        print("neigh graph idx: ", node_idx, node_idx_new)
        sub_label = np.expand_dims(sub_label, axis=0)

    # Add a leading batch dimension.
    sub_adj = np.expand_dims(sub_adj, axis=0)
    sub_feat = np.expand_dims(sub_feat, axis=0)

    adj = torch.tensor(sub_adj, dtype=torch.float)
    x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
    label = torch.tensor(sub_label, dtype=torch.long)

    if self.graph_mode:
        pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
        print("Graph predicted label: ", pred_label)
    else:
        pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
        print("Node predicted label: ", pred_label[node_idx_new])

    explainer = ExplainModule(
        adj=adj,
        x=x,
        model=self.model,
        label=label,
        args=self.args,
        writer=self.writer,
        graph_idx=self.graph_idx,
        graph_mode=self.graph_mode,
    )
    if self.args.gpu:
        explainer = explainer.cuda()

    self.model.eval()

    # gradient baseline
    if model == "grad":
        explainer.zero_grad()
        adj_grad = torch.abs(
            explainer.adj_feat_grad(node_idx_new, pred_label[node_idx_new])[0]
        )[graph_idx]
        masked_adj = adj_grad + adj_grad.t()
        masked_adj = torch.sigmoid(masked_adj)
        masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze()
    else:
        explainer.train()
        begin_time = time.time()
        # Loop-invariant: only used as the ``label`` argument when logging.
        single_subgraph_label = sub_label.squeeze()
        for epoch in range(self.args.num_epochs):
            explainer.zero_grad()
            explainer.optimizer.zero_grad()
            ypred, adj_atts = explainer(node_idx_new, unconstrained=unconstrained)
            loss = explainer.loss(ypred, pred_label, node_idx_new, epoch)
            loss.backward()

            explainer.optimizer.step()
            if explainer.scheduler is not None:
                explainer.scheduler.step()

            mask_density = explainer.mask_density()
            if self.print_training:
                print(
                    "epoch: ",
                    epoch,
                    "; loss: ",
                    loss.item(),
                    "; mask density: ",
                    mask_density.item(),
                    "; pred: ",
                    ypred,
                )

            if self.writer is not None:
                self.writer.add_scalar("mask/density", mask_density, epoch)
                self.writer.add_scalar(
                    "optimization/lr",
                    explainer.optimizer.param_groups[0]["lr"],
                    epoch,
                )
                if epoch % 25 == 0:
                    explainer.log_mask(epoch)
                    explainer.log_masked_adj(
                        node_idx_new, epoch, label=single_subgraph_label
                    )
                    explainer.log_adj_grad(
                        node_idx_new, pred_label, epoch, label=single_subgraph_label
                    )

                if epoch == 0:
                    if self.model.att:
                        # explain node
                        print("adj att size: ", adj_atts.size())
                        adj_att = torch.sum(adj_atts[0], dim=2)
                        # Follow the attention tensor's device instead of
                        # forcing CUDA, so CPU-only runs work as well.
                        node_adj_att = adj_att * adj.float().to(adj_att.device)
                        io_utils.log_matrix(
                            self.writer, node_adj_att[0], "att/matrix", epoch
                        )
                        node_adj_att = node_adj_att[0].cpu().detach().numpy()
                        G = io_utils.denoise_graph(
                            node_adj_att,
                            node_idx_new,
                            threshold=3.8,
                            max_component=True,
                        )
                        io_utils.log_graph(
                            self.writer,
                            G,
                            name="att/graph",
                            identify_self=not self.graph_mode,
                            nodecolor="label",
                            edge_vmax=None,
                            args=self.args,
                        )
            # Non-"exp" methods only need the attention of one forward pass.
            if model != "exp":
                break

        print("finished training in ", time.time() - begin_time)
        if model == "exp":
            masked_adj = (
                explainer.masked_adj[0].cpu().detach().numpy() * sub_adj.squeeze()
            )
        else:
            adj_atts = torch.sigmoid(adj_atts).squeeze()
            masked_adj = adj_atts.cpu().detach().numpy() * sub_adj.squeeze()

    fname = 'masked_adj_' + io_utils.gen_explainer_prefix(self.args) + (
        'node_idx_' + str(node_idx) + 'graph_idx_' + str(self.graph_idx) + '.npy')
    with open(os.path.join(self.args.logdir, fname), 'wb') as outfile:
        np.save(outfile, np.asarray(masked_adj.copy()))
        print("Saved adjacency matrix to ", fname)
    return masked_adj
def explain_nodes(self, node_indices, args, graph_idx=0):
    """Explain nodes and align the first two explanations.

    Every node in ``node_indices`` is explained with ``self.explain``. The
    explanations of the first two nodes (reference and current) are denoised
    at threshold 0.1, aligned with ``self.align``, and the reference, the
    un-aligned and the aligned graphs are logged through ``self.writer``.

    Args:
        node_indices: Indices of the nodes to be explained. At least two
            entries are required; only the first two are aligned.
        args: Program arguments, forwarded to ``self.align``.
        graph_idx: Index of the graph to explain the nodes from (if
            multiple).

    Returns:
        list: One masked adjacency per entry of ``node_indices``.
    """
    masked_adjs = [
        self.explain(node_idx, graph_idx=graph_idx) for node_idx in node_indices
    ]
    ref_idx = node_indices[0]
    ref_adj = masked_adjs[0]
    curr_idx = node_indices[1]
    curr_adj = masked_adjs[1]

    # Extract neighbourhoods from the same graph the explanations came from.
    new_ref_idx, _, ref_feat, _, _ = self.extract_neighborhood(ref_idx, graph_idx)
    new_curr_idx, _, curr_feat, _, _ = self.extract_neighborhood(curr_idx, graph_idx)

    G_ref = io_utils.denoise_graph(ref_adj, new_ref_idx, ref_feat, threshold=0.1)
    denoised_ref_feat = np.array(
        [G_ref.nodes[node]["feat"] for node in G_ref.nodes()]
    )
    denoised_ref_adj = nx.to_numpy_matrix(G_ref)
    # ref center node
    ref_node_idx = list(G_ref.nodes()).index(new_ref_idx)

    G_curr = io_utils.denoise_graph(
        curr_adj, new_curr_idx, curr_feat, threshold=0.1
    )
    denoised_curr_feat = np.array(
        [G_curr.nodes[node]["feat"] for node in G_curr.nodes()]
    )
    denoised_curr_adj = nx.to_numpy_matrix(G_curr)
    # curr center node
    curr_node_idx = list(G_curr.nodes()).index(new_curr_idx)

    P, aligned_adj, aligned_feat = self.align(
        denoised_ref_feat,
        denoised_ref_adj,
        ref_node_idx,
        denoised_curr_feat,
        denoised_curr_adj,
        curr_node_idx,
        args=args,
    )
    io_utils.log_matrix(self.writer, P, "align/P", 0)

    G_ref = nx.convert_node_labels_to_integers(G_ref)
    io_utils.log_graph(self.writer, G_ref, "align/ref")
    G_curr = nx.convert_node_labels_to_integers(G_curr)
    io_utils.log_graph(self.writer, G_curr, "align/before")

    P = P.cpu().detach().numpy()
    aligned_adj = aligned_adj.cpu().detach().numpy()
    aligned_feat = aligned_feat.cpu().detach().numpy()

    # Row of P with the largest weight in the current centre node's column.
    aligned_idx = np.argmax(P[:, curr_node_idx])
    print("aligned self: ", aligned_idx)
    G_aligned = io_utils.denoise_graph(
        aligned_adj, aligned_idx, aligned_feat, threshold=0.5
    )
    io_utils.log_graph(self.writer, G_aligned, "mask/aligned")

    return masked_adjs
def explain_nodes_gnn_stats(self, node_indices, args, graph_idx=0, model="exp", K=10):
    """Explain several nodes and compute edge/node quality statistics.

    Each node is explained with ``self.explain`` (the total time is
    printed), the denoised explanation graph is logged, and three statistics
    are computed:

    * ``accuracy``: per node, the fraction of the ``k`` highest-scored
      entries of ``pred`` whose ``real`` value is set (``k`` is 12 for
      dataset ``syn5`` and 6 otherwise).
    * ``auc_all``: ROC-AUC over the concatenated ``(pred, real)`` pairs. The
      matching precision-recall curve is saved to
      ``log/pr/pr_<dataset>_<model>.png`` and the score to
      ``log/pr/auc_<dataset>_<model>.txt``.
    * ``node_imp``: per node, the fraction of the ``K`` most important nodes
      (mean weight of non-zero incident edges) whose index lies in
      ``new_idx + 1 .. new_idx + K - 1``.

    Args:
        node_indices: Sequence of node indices to explain. Must be non-empty,
            because the predictions are concatenated with ``np.concatenate``.
        args: Program arguments; ``args.dataset`` selects ``K``.
        graph_idx: Graph index forwarded to ``self.explain`` and
            ``self.extract_neighborhood``.
        model: Explanation method name, forwarded to ``self.explain`` and
            used in log tags and output file names.
        K: Kept for interface compatibility. The value is always replaced by
            a dataset-specific constant (6 for ``syn4``, 8 for ``syn5``,
            5 otherwise).

    Returns:
        tuple: ``(masked_adjs, accuracy, auc_all, node_imp)``.
    """
    start = time.time()
    masked_adjs = [
        self.explain(node_idx, graph_idx=graph_idx, model=model)
        for node_idx in node_indices
    ]
    # Number of edges in the shape introduced into the dataset.
    k = 12 if self.args.dataset == 'syn5' else 6
    end = time.time()
    print('GNNE Time:', end - start)

    # Number of top nodes checked for membership of the shape. The chain
    # must be if/elif/else: with two independent ``if`` statements the
    # trailing ``else`` would overwrite the 'syn4' value.
    if args.dataset == 'syn4':
        K = 6
    elif args.dataset == 'syn5':
        K = 8
    else:
        K = 5

    pred_all = []
    real_all = []
    node_imp = []
    for i, idx in enumerate(node_indices):
        # Use the same graph the explanation was computed on.
        new_idx, _, feat, _, _ = self.extract_neighborhood(idx, graph_idx)
        G = io_utils.denoise_graph(masked_adjs[i], new_idx, feat, threshold_num=20)
        pred, real = self.make_pred_real(masked_adjs[i], new_idx)
        pred_all.append(pred)
        real_all.append(real)
        denoised_adj = nx.to_numpy_matrix(G)

        # Vizu
        io_utils.log_graph(
            self.writer,
            G,
            "graph/{}_{}_{}".format(self.args.dataset, model, i),
            identify_self=True,
        )

        # Importance of a node = mean weight of its non-zero incident edges.
        n = [row[row != 0].mean() for row in denoised_adj]

        # Among the K most important nodes, how many are members of the shape
        # (assumed to occupy the indices right after the query node).
        top_nodes = set(np.array(G.nodes())[np.argsort(n)[-K:]])
        shape_nodes = set(range(new_idx + 1, new_idx + K))
        node_imp.append(len(top_nodes.intersection(shape_nodes)) / (K - 1))

    # Edge accuracy: how many of the top-k scored edges belong to the shape.
    accuracy = [
        np.sum(real_obs[np.argsort(obs)[-k:]]) / k
        for obs, real_obs in zip(pred_all, real_all)
    ]

    pred_all = np.concatenate(pred_all, axis=0)
    real_all = np.concatenate(real_all, axis=0)

    auc_all = roc_auc_score(real_all, pred_all)
    precision, recall, _ = precision_recall_curve(real_all, pred_all)

    # Non-interactive backend so the figure can be saved without a display.
    plt.switch_backend("agg")
    plt.plot(recall, precision)
    plt.savefig("log/pr/pr_" + self.args.dataset + "_" + model + ".png")
    plt.close()

    with open("log/pr/auc_" + self.args.dataset + "_" + model + ".txt", "w") as f:
        # Record the method that was actually evaluated, matching the file name.
        f.write("dataset: {}, model: {}, auc: {}\n".format(
            self.args.dataset, model, str(auc_all)))

    return masked_adjs, accuracy, auc_all, node_imp