def log_mask(self, epoch):
    """Log the explainer's masks to TensorBoard.

    Writes heat-map images of the edge mask (``mask/mask``), of the masked
    adjacency (``mask/adj``) and, when ``self.args.mask_bias`` is set, of the
    mask bias (``mask/bias``). The sigmoid of the feature mask is logged as a
    matrix (``mask/feat_mask``) via ``io_utils.log_matrix``.

    Args:
        epoch: Global step passed to the TensorBoard writer.
    """
    # Non-interactive backend so figures can be rendered without a display.
    plt.switch_backend("agg")

    def _log_heatmap(matrix, tag):
        """Render a 2-D array as a BuPu heat map and add it to the writer as `tag`."""
        fig = plt.figure(figsize=(4, 3), dpi=400)
        plt.imshow(matrix, cmap=plt.get_cmap("BuPu"))
        cbar = plt.colorbar()
        # Paint colour-bar segment edges in their own face colour.
        cbar.solids.set_edgecolor("face")
        plt.tight_layout()
        fig.canvas.draw()
        # NOTE(review): tensorboardX's figure_to_image closes the figure by
        # default (close=True) — confirm for the pinned tensorboardX version.
        self.writer.add_image(tag, tensorboardX.utils.figure_to_image(fig), epoch)

    _log_heatmap(self.mask.cpu().detach().numpy(), "mask/mask")

    # The feature mask is logged as a matrix of sigmoid-squashed values
    # rather than as a rendered heat map.
    io_utils.log_matrix(
        self.writer, torch.sigmoid(self.feat_mask), "mask/feat_mask", epoch
    )

    # Use [0] to remove the batch dim of the masked adjacency.
    _log_heatmap(self.masked_adj[0].cpu().detach().numpy(), "mask/adj")

    if self.args.mask_bias:
        # The bias is plotted as stored; no batch dimension is stripped here.
        _log_heatmap(self.mask_bias.cpu().detach().numpy(), "mask/bias")
def log_adj_grad(self, node_idx, pred_label, epoch, label=None, log_adj=False):
    """Log gradient-based saliency of the adjacency and node features.

    Gets the gradients w.r.t. the adjacency and the features from
    ``self.adj_feat_grad``. It takes the absolute adjacency gradient,
    symmetrises it, zeroes entries where ``self.adj`` is zero, and logs a
    denoised graph built from it under ``grad/graph``. In graph mode it also
    logs the graph of the current masked adjacency (``grad/graph_orig``) and
    the feature gradient (``grad/feat``).

    Args:
        node_idx: Index of the explained node in the explainer's subgraph.
        pred_label: Predicted label. It is used as-is in graph mode and
            indexed with ``node_idx`` in node mode.
        epoch: Global step passed to the TensorBoard writer.
        label: Unused; kept for interface compatibility.
        log_adj: If True, also log the masked and original adjacency
            gradients. Setting it also sets ``self.adj.requires_grad = False``.
            Defaults to False, which matches the previous hard-coded behaviour.
    """
    if self.graph_mode:
        predicted_label = pred_label
        adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label)
        adj_grad = torch.abs(adj_grad)[0]
        # Sum the feature gradients over dim 0 of x_grad[0], as a column vector.
        x_grad = torch.sum(x_grad[0], 0, keepdim=True).t()
    else:
        predicted_label = pred_label[node_idx]
        adj_grad, x_grad = self.adj_feat_grad(node_idx, predicted_label)
        # NOTE(review): graph mode and masked_adj use [0] for the batch dim;
        # confirm that self.graph_idx is always a valid index here.
        adj_grad = torch.abs(adj_grad)[self.graph_idx]
        # Feature gradient of the explained node only, as a column vector.
        x_grad = x_grad[self.graph_idx][node_idx][:, np.newaxis]

    # Symmetrise, then keep only entries where the adjacency is non-zero.
    adj_grad = (adj_grad + adj_grad.t()) / 2
    adj_grad = (adj_grad * self.adj).squeeze()

    if log_adj:
        io_utils.log_matrix(self.writer, adj_grad, "grad/adj_masked", epoch)
        self.adj.requires_grad = False
        io_utils.log_matrix(self.writer, self.adj.squeeze(), "grad/adj_orig", epoch)

    # Use [0] to remove the batch dim.
    masked_adj = self.masked_adj[0].cpu().detach().numpy()

    # Only for graph mode, since many node neighbourhoods in the synthetic
    # tasks are too large to visualise.
    if self.graph_mode:
        G = io_utils.denoise_graph(
            masked_adj, node_idx, feat=self.x[0], threshold=None, max_component=False
        )
        io_utils.log_graph(
            self.writer,
            G,
            name="grad/graph_orig",
            epoch=epoch,
            identify_self=False,
            label_node_feat=True,
            nodecolor="feat",
            edge_vmax=None,
            args=self.args,
        )
        io_utils.log_matrix(self.writer, x_grad, "grad/feat", epoch)

    # .cpu() (a no-op for CPU tensors) keeps this valid for GPU tensors, like
    # the other tensor -> numpy conversions in this module.
    adj_grad = adj_grad.cpu().detach().numpy()
    if self.graph_mode:
        print("GRAPH model")
        G = io_utils.denoise_graph(
            adj_grad,
            node_idx,
            feat=self.x[0],
            threshold=0.0003,
            max_component=True,
        )
        io_utils.log_graph(
            self.writer,
            G,
            name="grad/graph",
            epoch=epoch,
            identify_self=False,
            label_node_feat=True,
            nodecolor="feat",
            edge_vmax=None,
            args=self.args,
        )
    else:
        # Node mode: keep a fixed number of edges (threshold_num) instead of
        # a weight threshold.
        G = io_utils.denoise_graph(adj_grad, node_idx, threshold_num=12)
        io_utils.log_graph(
            self.writer, G, name="grad/graph", epoch=epoch, args=self.args
        )
def explain(
    self, node_idx, graph_idx=0, graph_mode=False, unconstrained=False, model="exp"
):
    """Explain a single node (or whole-graph) prediction.

    Args:
        node_idx: Node to explain. In node mode it is mapped to its index
            inside the extracted neighbourhood subgraph.
        graph_idx: Graph to take the adjacency, features and labels from.
        graph_mode: If True, use the whole graph ``graph_idx``. If False,
            use the neighbourhood returned by ``self.extract_neighborhood``.
        unconstrained: Forwarded to the explainer's forward pass.
        model: ``"exp"`` optimises the explanation mask for
            ``self.args.num_epochs`` epochs. ``"grad"`` uses an adjacency
            gradient baseline. Any other value runs one training step and
            uses the sigmoid of the returned attention.

    Returns:
        numpy array: The explanation weights multiplied element-wise by the
        subgraph adjacency. It is also saved as a ``.npy`` file under
        ``self.args.logdir``.
    """
    # index of the query node in the new adj
    if graph_mode:
        node_idx_new = node_idx
        sub_adj = self.adj[graph_idx]
        sub_feat = self.feat[graph_idx, :]
        sub_label = self.label[graph_idx]
        # NOTE(review): self.adj is indexed by graph on the line above, so
        # shape[0] may count graphs rather than nodes — confirm. `neighbors`
        # is only read below when self.graph_mode is False.
        neighbors = np.asarray(range(self.adj.shape[0]))
    else:
        print("node label: ", self.label[graph_idx][node_idx])
        node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood(
            node_idx, graph_idx
        )
        print("neigh graph idx: ", node_idx, node_idx_new)
        sub_label = np.expand_dims(sub_label, axis=0)

    # Add a leading batch dimension of size 1: the explainer works on one subgraph.
    sub_adj = np.expand_dims(sub_adj, axis=0)
    sub_feat = np.expand_dims(sub_feat, axis=0)

    adj = torch.tensor(sub_adj, dtype=torch.float)
    x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
    label = torch.tensor(sub_label, dtype=torch.long)

    if self.graph_mode:
        pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
        print("Graph predicted label: ", pred_label)
    else:
        pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
        print("Node predicted label: ", pred_label[node_idx_new])

    explainer = ExplainModule(
        adj=adj,
        x=x,
        model=self.model,
        label=label,
        args=self.args,
        writer=self.writer,
        graph_idx=self.graph_idx,
        graph_mode=self.graph_mode,
    )
    if self.args.gpu:
        explainer = explainer.cuda()

    self.model.eval()

    # gradient baseline
    if model == "grad":
        explainer.zero_grad()
        # The explainer was built from a batch-of-one adjacency (expand_dims
        # above), so [0] removes the batch dim. Indexing with graph_idx
        # raised IndexError for any graph_idx > 0.
        adj_grad = torch.abs(
            explainer.adj_feat_grad(node_idx_new, pred_label[node_idx_new])[0]
        )[0]
        masked_adj = adj_grad + adj_grad.t()
        masked_adj = torch.sigmoid(masked_adj)
        masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze()
    else:
        explainer.train()
        begin_time = time.time()
        for epoch in range(self.args.num_epochs):
            explainer.zero_grad()
            explainer.optimizer.zero_grad()
            ypred, adj_atts = explainer(node_idx_new, unconstrained=unconstrained)
            loss = explainer.loss(ypred, pred_label, node_idx_new, epoch)
            loss.backward()

            explainer.optimizer.step()
            if explainer.scheduler is not None:
                explainer.scheduler.step()

            mask_density = explainer.mask_density()
            if self.print_training:
                print(
                    "epoch: ",
                    epoch,
                    "; loss: ",
                    loss.item(),
                    "; mask density: ",
                    mask_density.item(),
                    "; pred: ",
                    ypred,
                )
            single_subgraph_label = sub_label.squeeze()

            if self.writer is not None:
                self.writer.add_scalar("mask/density", mask_density, epoch)
                self.writer.add_scalar(
                    "optimization/lr",
                    explainer.optimizer.param_groups[0]["lr"],
                    epoch,
                )
                if epoch % 25 == 0:
                    explainer.log_mask(epoch)
                    explainer.log_masked_adj(
                        node_idx_new, epoch, label=single_subgraph_label
                    )
                    explainer.log_adj_grad(
                        node_idx_new, pred_label, epoch, label=single_subgraph_label
                    )

                if epoch == 0:
                    if self.model.att:
                        # explain node
                        print("adj att size: ", adj_atts.size())
                        adj_att = torch.sum(adj_atts[0], dim=2)
                        # Move adj to wherever the attention lives. The old
                        # unconditional .cuda() crashed on CPU-only runs.
                        node_adj_att = adj_att * adj.float().to(adj_att.device)
                        io_utils.log_matrix(
                            self.writer, node_adj_att[0], "att/matrix", epoch
                        )
                        node_adj_att = node_adj_att[0].cpu().detach().numpy()
                        G = io_utils.denoise_graph(
                            node_adj_att,
                            node_idx_new,
                            threshold=3.8,
                            max_component=True,
                        )
                        io_utils.log_graph(
                            self.writer,
                            G,
                            name="att/graph",
                            identify_self=not self.graph_mode,
                            nodecolor="label",
                            edge_vmax=None,
                            args=self.args,
                        )
            # Non-mask models only need a single forward pass.
            if model != "exp":
                break

        print("finished training in ", time.time() - begin_time)
        if model == "exp":
            masked_adj = (
                explainer.masked_adj[0].cpu().detach().numpy() * sub_adj.squeeze()
            )
        else:
            adj_atts = torch.sigmoid(adj_atts).squeeze()
            masked_adj = adj_atts.cpu().detach().numpy() * sub_adj.squeeze()

    fname = 'masked_adj_' + io_utils.gen_explainer_prefix(self.args) + (
        'node_idx_' + str(node_idx) + 'graph_idx_' + str(self.graph_idx) + '.npy'
    )
    with open(os.path.join(self.args.logdir, fname), 'wb') as outfile:
        np.save(outfile, np.asarray(masked_adj.copy()))
        print("Saved adjacency matrix to ", fname)
    return masked_adj
def explain_nodes(self, node_indices, args, graph_idx=0):
    """Explain several nodes, then align the second explanation to the first.

    Args:
        node_indices: Indices of the nodes to be explained. Only the first
            two take part in the alignment step.
        args: Program arguments (mainly for logging paths). Forwarded to
            ``self.align``.
        graph_idx: Index of the graph to explain the nodes from (if multiple).

    Returns:
        list: One masked adjacency from ``self.explain`` per node, in the
        order of ``node_indices``.
    """
    masked_adjs = [
        self.explain(node_idx, graph_idx=graph_idx) for node_idx in node_indices
    ]
    # Alignment needs a reference node and a second node to align. With fewer
    # than two nodes there is nothing to align; previously this raised IndexError.
    if len(masked_adjs) < 2:
        return masked_adjs

    def _denoised_view(masked_adj, center_idx, feat):
        """Denoise an explanation; return (graph, node feats, adjacency, center position)."""
        G = io_utils.denoise_graph(masked_adj, center_idx, feat, threshold=0.1)
        denoised_feat = np.array([G.nodes[node]["feat"] for node in G.nodes()])
        # nx.to_numpy_matrix was removed in NetworkX 3.0. This produces the
        # same np.matrix result on old and new versions.
        denoised_adj = np.asmatrix(nx.to_numpy_array(G))
        center_pos = list(G.nodes()).index(center_idx)
        return G, denoised_feat, denoised_adj, center_pos

    ref_idx, curr_idx = node_indices[0], node_indices[1]
    ref_adj, curr_adj = masked_adjs[0], masked_adjs[1]

    # Extract neighbourhoods from the same graph the masks were computed on.
    new_ref_idx, _, ref_feat, _, _ = self.extract_neighborhood(ref_idx, graph_idx)
    new_curr_idx, _, curr_feat, _, _ = self.extract_neighborhood(curr_idx, graph_idx)

    G_ref, denoised_ref_feat, denoised_ref_adj, ref_node_idx = _denoised_view(
        ref_adj, new_ref_idx, ref_feat
    )
    G_curr, denoised_curr_feat, denoised_curr_adj, curr_node_idx = _denoised_view(
        curr_adj, new_curr_idx, curr_feat
    )

    P, aligned_adj, aligned_feat = self.align(
        denoised_ref_feat,
        denoised_ref_adj,
        ref_node_idx,
        denoised_curr_feat,
        denoised_curr_adj,
        curr_node_idx,
        args=args,
    )
    io_utils.log_matrix(self.writer, P, "align/P", 0)

    G_ref = nx.convert_node_labels_to_integers(G_ref)
    io_utils.log_graph(self.writer, G_ref, "align/ref")
    G_curr = nx.convert_node_labels_to_integers(G_curr)
    io_utils.log_graph(self.writer, G_curr, "align/before")

    P = P.cpu().detach().numpy()
    aligned_adj = aligned_adj.cpu().detach().numpy()
    aligned_feat = aligned_feat.cpu().detach().numpy()

    # Row of P with the largest weight in the current center node's column;
    # presumably the center's position in the aligned graph — confirm with align().
    aligned_idx = np.argmax(P[:, curr_node_idx])
    print("aligned self: ", aligned_idx)
    G_aligned = io_utils.denoise_graph(
        aligned_adj, aligned_idx, aligned_feat, threshold=0.5
    )
    io_utils.log_graph(self.writer, G_aligned, "mask/aligned")

    return masked_adjs