def recon_loss(z, pos_edge_index, neg_edge_index=None, decoder=None, eps=1e-15):
    r"""Given latent variables :obj:`z`, computes the binary cross
    entropy loss for positive edges :obj:`pos_edge_index` and negative
    sampled edges.

    Args:
        z (Tensor): The latent space :math:`\mathbf{Z}`.
        pos_edge_index (LongTensor): The positive edges to train against.
        neg_edge_index (LongTensor, optional): The negative edges to train
            against. If not given, uses negative sampling to calculate
            negative edges. (default: :obj:`None`)
        decoder (callable, optional): Decoder mapping latent variables and
            edge indices to edge probabilities, called as
            ``decoder(z, edge_index, sigmoid=True)``. If not given, a fresh
            :class:`InnerProductDecoder` is used so existing callers keep
            their previous behavior. (default: :obj:`None`)
        eps (float, optional): Numerical-stability constant added inside
            the logarithms to avoid ``log(0)``. (default: :obj:`1e-15`)

    :rtype: :class:`Tensor` (scalar loss, positive + negative term)
    """
    if decoder is None:
        decoder = InnerProductDecoder()

    pos_loss = -torch.log(decoder(z, pos_edge_index, sigmoid=True) + eps).mean()

    # Do not include self-loops in negative samples: strip any existing
    # self-loops, then add the full set so `negative_sampling` treats every
    # (i, i) pair as an existing edge and never samples it as a negative.
    pos_edge_index, _ = remove_self_loops(pos_edge_index)
    pos_edge_index, _ = add_self_loops(pos_edge_index)
    if neg_edge_index is None:
        neg_edge_index = negative_sampling(pos_edge_index, z.size(0))

    neg_loss = -torch.log(1 - decoder(z, neg_edge_index, sigmoid=True) + eps).mean()

    return pos_loss + neg_loss
def forward(self, x: Tensor, edge_index: Tensor, **kwargs ):
    r"""Produce DeepLIFT-style per-edge explanation masks, one per class.

    The input features are concatenated with an all-zero reference graph,
    the model is run once on the joint batch, and node attributions are
    computed as ``grad * (input - reference)``; a node score is then
    averaged over each edge's endpoints to obtain an edge mask.

    Args:
        x (Tensor): Node feature matrix.
        edge_index (Tensor): Edge indices of the graph to explain.
        **kwargs: ``node_idx`` (required for node-level tasks) and
            ``sparsity`` (forwarded to :meth:`control_sparsity`).

    :rtype: (None, list, list) — no walks, per-class edge masks, and
        related predictions for evaluation.
    """
    # --- run the model once ---
    super().forward(x=x, edge_index=edge_index, **kwargs)
    self.model.eval()
    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Node-level task: restrict evaluation to the k-hop subgraph around node_idx.
    if data_args.model_level == 'node':
        node_idx = kwargs.get('node_idx')
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(
            node_idx, self.__num_hops__, self_loop_edge_index, relabel_nodes=True,
            num_nodes=None, flow=self.__flow__())

    # --- add shap calculation hook ---
    # NOTE(review): shap._register_hooks presumably installs the DeepLIFT
    # rescale-rule hooks on every submodule — confirm against the DeepLift impl.
    shap = DeepLift(self.model)
    self.model.apply(shap._register_hooks)

    # Joint batch: original graph followed by an all-zero reference copy.
    # Reference edges are the same topology shifted by the node count.
    inp_with_ref = torch.cat([x, torch.zeros(x.shape, device=self.device, dtype=torch.float)], dim=0).requires_grad_(True)
    edge_index_with_ref = torch.cat([edge_index, edge_index + x.shape[0]], dim=1)
    # batch vector: 0 for the original graph's nodes, 1 for the reference's.
    batch = torch.arange(2, dtype=torch.long, device=self.device).view(2, 1).repeat(1, x.shape[0]).reshape(-1)
    out = self.model(inp_with_ref, edge_index_with_ref, batch)

    labels = tuple(i for i in range(data_args.num_classes))
    ex_labels = tuple(torch.tensor([label]).to(data_args.device) for label in labels)

    print('#D#Mask Calculate...')
    masks = []
    for ex_label in ex_labels:
        if self.explain_graph:
            # Graph-level: one logit per (original, reference) graph.
            f = torch.unbind(out[:, ex_label])
        else:
            # Node-level: logits of the target node in both copies.
            f = torch.unbind(out[[node_idx, node_idx + x.shape[0]], ex_label])
        # Gradient of the class score w.r.t. the joint input; retain_graph
        # because the same forward pass is reused for every class.
        (m, ) = torch.autograd.grad(outputs=f, inputs=inp_with_ref, retain_graph=True)
        inp, inp_ref = torch.chunk(inp_with_ref, 2)
        # DeepLIFT attribution: gradient times difference-from-reference,
        # summed over the feature dimension -> one score per node.
        attr_wo_relu = (torch.chunk(m, 2)[0] * (inp - inp_ref)).sum(1)

        # Edge score = mean of its two endpoint node scores.
        mask = attr_wo_relu.squeeze()
        mask = (mask[self_loop_edge_index[0]] + mask[self_loop_edge_index[1]]) / 2
        mask = self.control_sparsity(mask, kwargs.get('sparsity'))
        masks.append(mask.detach())

    # Store related predictions for further evaluation.
    shap._remove_hooks()
    print('#D#Predict...')
    with torch.no_grad():
        with self.connect_mask(self):
            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

    return None, masks, related_preds
def forward(self, x, edge_index, mask_features=False, positive=True, **kwargs):
    r"""Learns and returns per-class edge masks that explain the GNN's
    prediction (GNNExplainer-style optimization).

    Args:
        x (Tensor): Node feature matrix.
        edge_index (LongTensor): The edge indices.
        mask_features (bool, optional): Present in the signature but unused
            in this body — feature masking is presumably handled elsewhere.
        positive (bool, optional): Unused here; kept for interface
            compatibility.
        **kwargs (optional): ``node_idx`` (required for node-level tasks),
            ``sparsity``, and additional arguments passed to the GNN module.

    :rtype: (None, list, list) — no walks, per-class edge masks, and
        related predictions for evaluation.
    """
    self.model.eval()
    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Only operate on a k-hop subgraph around `node_idx`.
    # Get subgraph and relabel the node, mapping is the relabeled given node_idx.
    if data_args.model_level == 'node':
        node_idx = kwargs.get('node_idx')
        self.node_idx = node_idx
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(
            node_idx, self.__num_hops__, self_loop_edge_index, relabel_nodes=True,
            num_nodes=None, flow=self.__flow__())

    # Assume the mask we will predict: one explanation per class label.
    labels = tuple(i for i in range(data_args.num_classes))
    ex_labels = tuple(torch.tensor([label]).to(data_args.device) for label in labels)

    # Calculate mask: reset and re-initialize the learnable masks before
    # optimizing an explanation for each class.
    print('#D#Masks calculate...')
    edge_masks = []
    for ex_label in ex_labels:
        self.__clear_masks__()
        self.__set_masks__(x, self_loop_edge_index)
        edge_masks.append(self.control_sparsity(self.gnn_explainer_alg(x, edge_index, ex_label), sparsity=kwargs.get('sparsity')))
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))

    print('#D#Predict...')

    with torch.no_grad():
        related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)

    # Leave the model clean for subsequent calls.
    self.__clear_masks__()

    return None, edge_masks, related_preds
def visualize_walks(self, node_idx, edge_index, walks, edge_mask, y=None,
                    threshold=None, **kwargs) -> Tuple[Axes, nx.DiGraph]:
    r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
    :attr:`edge_mask`, overlaying curved paths for the given walks
    (red for positive scores, blue for negative).

    Args:
        node_idx (int): The node id to explain.
        edge_index (LongTensor): The edge indices.
        walks (dict): ``'ids'`` — walk edge indices into the self-looped
            edge index; ``'score'`` — per-walk scores.
        edge_mask (Tensor): The edge mask, aligned with the self-looped
            edge index.
        y (Tensor, optional): The ground-truth node-prediction labels used
            as node colorings. (default: :obj:`None`)
        threshold (float, optional): Sets a threshold for visualizing
            important edges. If set to :obj:`None`, will visualize all
            edges with transparancy indicating the importance of edges.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments passed to :func:`nx.draw`.

    :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
    """
    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=kwargs.get('num_nodes'))
    assert edge_mask.size(0) == self_loop_edge_index.size(1)

    if self.molecule:
        # For molecules, `y` carries atomic numbers; keep a copy for labels.
        atomic_num = torch.clone(y)

    # Only operate on a k-hop subgraph around `node_idx`.
    subset, edge_index, _, hard_edge_mask = subgraph(
        node_idx, self.__num_hops__, self_loop_edge_index, relabel_nodes=True,
        num_nodes=None, flow=self.__flow__())

    edge_mask = edge_mask[hard_edge_mask]

    # --- temp --- clamp +/- infinities so alpha values stay valid.
    edge_mask[edge_mask == float('inf')] = 1
    edge_mask[edge_mask == - float('inf')] = 0
    # ---

    if threshold is not None:
        # Binarize the mask at the given threshold.
        edge_mask = (edge_mask >= threshold).to(torch.float)

    # NOTE(review): for 'ba_lrp' y is rebuilt over the *relabeled* nodes but
    # then indexed by `subset` (original ids) below — verify this is intended.
    if data_args.dataset_name == 'ba_lrp':
        y = torch.zeros(edge_index.max().item() + 1, device=edge_index.device)
    if y is None:
        y = torch.zeros(edge_index.max().item() + 1, device=edge_index.device)
    else:
        y = y[subset]

    if self.molecule:
        # Color nodes by atomic number.
        atom_colors = {6: '#8c69c5', 7: '#71bcf0', 8: '#aef5f1', 9: '#bdc499', 15: '#c22f72', 16: '#f3ea19', 17: '#bdc499', 35: '#cc7161'}
        node_colors = [None for _ in range(y.shape[0])]
        for y_idx in range(y.shape[0]):
            node_colors[y_idx] = atom_colors[y[y_idx].int().tolist()]
    else:
        # Color nodes by class label (0-3).
        atom_colors = {0: '#8c69c5', 1: '#c56973', 2: '#a1c569', 3: '#69c5ba'}
        node_colors = [None for _ in range(y.shape[0])]
        for y_idx in range(y.shape[0]):
            node_colors[y_idx] = atom_colors[y[y_idx].int().tolist()]

    # Build a networkx graph carrying the mask as an edge attribute and
    # restore the original node ids.
    data = Data(edge_index=edge_index, att=edge_mask, y=y, num_nodes=y.size(0)).to('cpu')
    G = to_networkx(data, node_attrs=['y'], edge_attrs=['att'])
    mapping = {k: i for k, i in enumerate(subset.tolist())}
    G = nx.relabel_nodes(G, mapping)

    # Drawing defaults (only applied when the caller did not set them).
    kwargs['with_labels'] = kwargs.get('with_labels') or True
    kwargs['font_size'] = kwargs.get('font_size') or 8
    kwargs['node_size'] = kwargs.get('node_size') or 200
    kwargs['cmap'] = kwargs.get('cmap') or 'cool'

    # calculate Graph positions
    pos = nx.kamada_kawai_layout(G)
    ax = plt.gca()

    # Draw plain grey edges as a backdrop for the walk paths.
    for source, target, data in G.edges(data=True):
        ax.annotate(
            '', xy=pos[target], xycoords='data', xytext=pos[source],
            textcoords='data', arrowprops=dict(
                arrowstyle="-",
                lw=1.5,
                alpha=0.5,  # alpha control transparency
                color='grey',  # color control color
                shrinkA=sqrt(kwargs['node_size']) / 2.0,
                shrinkB=sqrt(kwargs['node_size']) / 2.0,
                connectionstyle="arc3,rad=0",  # rad control angle
            ))

    # --- try to draw a walk ---
    # Convert walk edge ids into sequences of node ids: the first edge
    # contributes both endpoints, later edges contribute only their target.
    walks_ids = walks['ids']
    walks_score = walks['score']
    walks_node_list = []
    for i in range(walks_ids.shape[1]):
        if i == 0:
            walks_node_list.append(self_loop_edge_index[:, walks_ids[:, i].view(-1)].view(2, -1))
        else:
            walks_node_list.append(self_loop_edge_index[1, walks_ids[:, i].view(-1)].view(1, -1))
    walks_node_ids = torch.cat(walks_node_list, dim=0).T

    # Keep only walks whose every node lies inside the drawn subgraph.
    walks_mask = torch.zeros(walks_node_ids.shape, dtype=bool, device=self.device)
    for n in G.nodes():
        walks_mask = walks_mask | (walks_node_ids == n)
    walks_mask = walks_mask.sum(1) == walks_node_ids.shape[1]

    sub_walks_node_ids = walks_node_ids[walks_mask]
    sub_walks_score = walks_score[walks_mask]

    # Draw each surviving walk as a Bezier path; alpha encodes score
    # magnitude relative to the extreme score of the same sign.
    for i, walk in enumerate(sub_walks_node_ids):
        verts = [pos[n.item()] for n in walk]
        if walk.shape[0] == 3:
            codes = [Path.MOVETO, Path.CURVE3, Path.CURVE3]
        else:
            codes = [Path.MOVETO, Path.CURVE4, Path.CURVE4, Path.CURVE4]
        path = Path(verts, codes)
        if sub_walks_score[i] > 0:
            patch = PathPatch(path, facecolor='none', edgecolor='red', lw=1.5,  # e1442a
                              alpha=(sub_walks_score[i] / (sub_walks_score.max() * 2)).item())
        else:
            patch = PathPatch(path, facecolor='none', edgecolor='blue', lw=1.5,  # 18d66b
                              alpha=(sub_walks_score[i] / (sub_walks_score.min() * 2)).item())
        ax.add_patch(patch)

    nx.draw_networkx_nodes(G, pos, node_color=node_colors, **kwargs)
    # define node labels
    if self.molecule:
        if x_args.nolabel:
            # Element symbol only.
            node_labels = {n: f'{self.table(atomic_num[n].int().item())}' for n in G.nodes()}
            nx.draw_networkx_labels(G, pos, labels=node_labels, **kwargs)
        else:
            # Node id plus element symbol.
            node_labels = {n: f'{n}:{self.table(atomic_num[n].int().item())}' for n in G.nodes()}
            nx.draw_networkx_labels(G, pos, labels=node_labels, **kwargs)
    else:
        if not x_args.nolabel:
            nx.draw_networkx_labels(G, pos, **kwargs)

    return ax, G
def visualize_graph(self, node_idx, edge_index, edge_mask, y=None,
                    threshold=None, **kwargs) -> Tuple[Axes, nx.DiGraph]:
    r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
    :attr:`edge_mask`, drawing each edge with width/transparency scaled by
    its mask value.

    Args:
        node_idx (int): The node id to explain.
        edge_index (LongTensor): The edge indices.
        edge_mask (Tensor): The edge mask, aligned with the self-looped
            edge index.
        y (Tensor, optional): The ground-truth node-prediction labels used
            as node colorings. (default: :obj:`None`)
        threshold (float, optional): Sets a threshold for visualizing
            important edges. If set to :obj:`None`, will visualize all
            edges with transparancy indicating the importance of edges.
            (default: :obj:`None`)
        **kwargs (optional): Additional arguments passed to :func:`nx.draw`.

    :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
    """
    edge_index, _ = add_self_loops(edge_index, num_nodes=kwargs.get('num_nodes'))
    assert edge_mask.size(0) == edge_index.size(1)

    if self.molecule:
        # For molecules, `y` carries atomic numbers; keep a copy for labels.
        atomic_num = torch.clone(y)

    # Only operate on a k-hop subgraph around `node_idx`.
    subset, edge_index, _, hard_edge_mask = subgraph(
        node_idx, self.__num_hops__, edge_index, relabel_nodes=True,
        num_nodes=None, flow=self.__flow__())

    edge_mask = edge_mask[hard_edge_mask]

    # --- temp --- clamp +/- infinities so width/alpha values stay valid.
    edge_mask[edge_mask == float('inf')] = 1
    edge_mask[edge_mask == - float('inf')] = 0
    # ---

    if threshold is not None:
        # Binarize the mask at the given threshold.
        edge_mask = (edge_mask >= threshold).to(torch.float)

    # NOTE(review): for 'ba_lrp' y is rebuilt over the *relabeled* nodes but
    # then indexed by `subset` (original ids) below — verify this is intended.
    if data_args.dataset_name == 'ba_lrp':
        y = torch.zeros(edge_index.max().item() + 1, device=edge_index.device)
    if y is None:
        y = torch.zeros(edge_index.max().item() + 1, device=edge_index.device)
    else:
        y = y[subset]

    if self.molecule:
        # Color nodes by atomic number.
        atom_colors = {6: '#8c69c5', 7: '#71bcf0', 8: '#aef5f1', 9: '#bdc499', 15: '#c22f72', 16: '#f3ea19', 17: '#bdc499', 35: '#cc7161'}
        node_colors = [None for _ in range(y.shape[0])]
        for y_idx in range(y.shape[0]):
            node_colors[y_idx] = atom_colors[y[y_idx].int().tolist()]
    else:
        # Color nodes by class label (0-3).
        atom_colors = {0: '#8c69c5', 1: '#c56973', 2: '#a1c569', 3: '#69c5ba'}
        node_colors = [None for _ in range(y.shape[0])]
        for y_idx in range(y.shape[0]):
            node_colors[y_idx] = atom_colors[y[y_idx].int().tolist()]

    # Build a networkx graph carrying the mask as an edge attribute and
    # restore the original node ids.
    data = Data(edge_index=edge_index, att=edge_mask, y=y, num_nodes=y.size(0)).to('cpu')
    G = to_networkx(data, node_attrs=['y'], edge_attrs=['att'])
    mapping = {k: i for k, i in enumerate(subset.tolist())}
    G = nx.relabel_nodes(G, mapping)

    # Drawing defaults (only applied when the caller did not set them).
    kwargs['with_labels'] = kwargs.get('with_labels') or True
    kwargs['font_size'] = kwargs.get('font_size') or 10
    kwargs['node_size'] = kwargs.get('node_size') or 250
    kwargs['cmap'] = kwargs.get('cmap') or 'cool'

    # calculate Graph positions
    pos = nx.kamada_kawai_layout(G)
    ax = plt.gca()

    # Draw each edge as an arrow; line width and alpha grow with the
    # edge's mask value ('att'), with floors to keep edges visible.
    for source, target, data in G.edges(data=True):
        ax.annotate(
            '', xy=pos[target], xycoords='data', xytext=pos[source],
            textcoords='data', arrowprops=dict(
                arrowstyle="->",
                lw=max(data['att'], 0.5) * 2,
                alpha=max(data['att'], 0.4),  # alpha control transparency
                color='#e1442a',  # color control color
                shrinkA=sqrt(kwargs['node_size']) / 2.0,
                shrinkB=sqrt(kwargs['node_size']) / 2.0,
                connectionstyle="arc3,rad=0.08",  # rad control angle
            ))
    nx.draw_networkx_nodes(G, pos, node_color=node_colors, **kwargs)
    # define node labels
    if self.molecule:
        if x_args.nolabel:
            # Element symbol only.
            node_labels = {n: f'{self.table(atomic_num[n].int().item())}' for n in G.nodes()}
            nx.draw_networkx_labels(G, pos, labels=node_labels, **kwargs)
        else:
            # Node id plus element symbol.
            node_labels = {n: f'{n}:{self.table(atomic_num[n].int().item())}' for n in G.nodes()}
            nx.draw_networkx_labels(G, pos, labels=node_labels, **kwargs)
    else:
        if not x_args.nolabel:
            nx.draw_networkx_labels(G, pos, **kwargs)

    return ax, G
def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
            edge_weight: OptTensor = None, **kwargs) -> Tensor:
    """Message-passing forward pass with optional explain-mode probing.

    Propagates node features over ``edge_index`` and applies the layer's
    ``self.nn`` network to the aggregated output. When
    ``data_args.task == 'explain'``, forward hooks record every leaf
    submodule of ``self.nn`` so the run can be decomposed into
    ``self.fc_steps`` — groups of modules delimited by ``nn.Linear``
    layers — for later relevance propagation.

    Args:
        x: Node features (a single tensor is promoted to a pair).
        edge_index: Graph connectivity.
        edge_weight: Optional per-edge weights; when given, reweighting is
            disabled and the weights are stored on the instance.
        **kwargs: ``probe`` (bool) — when truthy, each step's output tensor
            is kept in ``self.fc_steps``; otherwise outputs are dropped.

    Returns:
        Tensor: ``self.nn`` applied to the propagated features.
    """
    self.num_nodes = x.shape[0]
    if isinstance(x, Tensor):
        x: OptPairTensor = (x, x)

    # propagate_type: (x: OptPairTensor)
    if edge_weight is not None:
        # Caller supplied explicit weights: keep them and skip reweighting.
        self.edge_weight = edge_weight
        assert edge_weight.shape[0] == edge_index.shape[1]
        self.reweight = False
    else:
        # Normalize topology: drop existing self-loops, then add a full set.
        edge_index, _ = remove_self_loops(edge_index)
        self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)
        if self_loop_edge_index.shape[1] != edge_index.shape[1]:
            edge_index = self_loop_edge_index
        self.reweight = True
    out = self.propagate(edge_index, x=x[0], size=None)
    if data_args.task == 'explain':
        layer_extractor = []
        hooks = []

        def register_hook(module: nn.Module):
            # Only hook leaf modules (no children) to avoid double-recording.
            if not list(module.children()):
                hooks.append(module.register_forward_hook(forward_hook))

        def forward_hook(module: nn.Module, input: Tuple[Tensor], output: Tensor):
            # input contains x and edge_index
            layer_extractor.append((module, input[0], output))

        # --- register hooks ---
        self.nn.apply(register_hook)

        nn_out = self.nn(out)

        for hook in hooks:
            hook.remove()

        # Group the recorded leaf modules into "fc steps": a new step
        # starts at every nn.Linear boundary.
        fc_steps = []
        step = {'input': None, 'module': [], 'output': None}
        for layer in layer_extractor:
            if isinstance(layer[0], nn.Linear):
                if step['module']:
                    fc_steps.append(step)
                # step = {'input': layer[1], 'module': [], 'output': None}
                step = {'input': None, 'module': [], 'output': None}
            step['module'].append(layer[0])
            if kwargs.get('probe'):
                # Keep the activation so explainers can read it back.
                step['output'] = layer[2]
            else:
                step['output'] = None
        if step['module']:
            fc_steps.append(step)
        self.fc_steps = fc_steps
    else:
        nn_out = self.nn(out)

    return nn_out
def forward(self, x: Tensor, edge_index: Tensor, **kwargs ):
    r"""GNN-LRP explanation: scores message-passing *walks* through the
    network via layer-wise relevance propagation with the gamma rule.

    For every class label, each candidate walk (a sequence of edges, one
    per GNN layer) is scored by re-running the network with
    gamma-modified copies of each layer and masking propagation to the
    walk's path; the relevance is read off as a gradient at the input.

    Args:
        x (Tensor): Node feature matrix.
        edge_index (Tensor): Edge indices of the graph to explain.
        **kwargs: ``node_idx`` (required for node-level tasks) and
            ``sparsity`` (forwarded to :meth:`control_sparsity`).

    :rtype: (dict, list, list) — walks (``'ids'``/``'score'``), per-class
        edge masks, and related predictions for evaluation.
    """
    super().forward(x, edge_index, **kwargs)
    self.model.eval()

    # Decompose the model run into per-layer GNN steps and final FC steps.
    walk_steps, fc_steps = self.extract_step(x, edge_index, detach=False, split_fc=True)

    edge_index_with_loop, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Enumerate candidate walks (one edge per layer) over the self-looped graph.
    walk_indices_list = torch.tensor(
        self.walks_pick(edge_index_with_loop.cpu(), list(range(edge_index_with_loop.shape[1])),
                        num_layers=self.num_layers), device=self.device)
    if data_args.model_level == 'node':
        node_idx = kwargs.get('node_idx')
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(
            node_idx, self.__num_hops__, edge_index_with_loop, relabel_nodes=True,
            num_nodes=None, flow=self.__flow__())

        # walk indices list mask: keep only walks whose last edge ends at node_idx.
        edge2node_idx = edge_index_with_loop[1] == node_idx
        walk_indices_list_mask = edge2node_idx[walk_indices_list[:, -1]]
        walk_indices_list = walk_indices_list[walk_indices_list_mask]

    def compute_walk_score():
        # Scores every walk in `walk_indices_list` for the enclosing loop's
        # `label`, appending results to the (closure) list `walk_scores`.

        # hyper-parameter gamma: per-layer LRP-gamma strengths.
        epsilon = 1e-30   # prevent from zero division
        gamma = [2, 1, 1]

        # --- record original weights of GNN ---
        # Build, per layer, a deep copy whose weights are gamma-amplified
        # (w + gamma * relu(w)); originals are kept for reference.
        ori_gnn_weights = []
        gnn_gamma_modules = []
        clear_probe = x
        for i, walk_step in enumerate(walk_steps):
            modules = walk_step['module']
            gamma_ = gamma[i] if i <= 1 else 1
            if hasattr(modules[0], 'nn'):
                # GIN-style layer: run once without probing to populate its
                # internal fc_steps before copying.
                clear_probe = modules[0](clear_probe, edge_index, probe=False)
                # clear nodes that are not created by user
            gamma_module = copy.deepcopy(modules[0])
            if hasattr(modules[0], 'nn'):
                # NOTE(review): inner loop reuses the name `i`, shadowing the
                # outer layer index for the rest of this iteration — confirm
                # this matches upstream GNN-LRP.
                for i, fc_step in enumerate(gamma_module.fc_steps):
                    fc_modules = fc_step['module']
                    if hasattr(fc_modules[0], 'weight'):
                        ori_fc_weight = fc_modules[0].weight.data
                        fc_modules[0].weight.data = ori_fc_weight + gamma_ * ori_fc_weight
            else:
                ori_gnn_weights.append(modules[0].weight.data)
                gamma_module.weight.data = ori_gnn_weights[i] + gamma_ * ori_gnn_weights[i].relu()
            gnn_gamma_modules.append(gamma_module)

        # --- record original weights of fc layer ---
        ori_fc_weights = []
        fc_gamma_modules = []
        for i, fc_step in enumerate(fc_steps):
            modules = fc_step['module']
            gamma_module = copy.deepcopy(modules[0])
            if hasattr(modules[0], 'weight'):
                ori_fc_weights.append(modules[0].weight.data)
                gamma_ = 1
                gamma_module.weight.data = ori_fc_weights[i] + gamma_ * ori_fc_weights[i].relu()
            else:
                ori_fc_weights.append(None)
            fc_gamma_modules.append(gamma_module)

        # --- GNN_LRP implementation ---
        for walk_indices in walk_indices_list:
            # Node sequence visited by this walk: source of the first edge,
            # then the target of every edge.
            walk_node_indices = [edge_index_with_loop[0, walk_indices[0]]]
            for walk_idx in walk_indices:
                walk_node_indices.append(edge_index_with_loop[1, walk_idx])

            h = x.requires_grad_(True)
            for i, walk_step in enumerate(walk_steps):
                modules = walk_step['module']
                if hasattr(modules[0], 'nn'):
                    # for the specific 2-layer nn GINs.
                    # LRP-gamma applied per inner FC step: q = (p+eps) *
                    # detach(std/p) keeps the forward value equal to the
                    # standard activation while routing gradients through
                    # the gamma-modified module.
                    gin = modules[0]
                    run1 = gin(h, edge_index, probe=True)
                    std_h1 = gin.fc_steps[0]['output']
                    gamma_run1 = gnn_gamma_modules[i](h, edge_index, probe=True)
                    p1 = gnn_gamma_modules[i].fc_steps[0]['output']
                    q1 = (p1 + epsilon) * (std_h1 / (p1 + epsilon)).detach()

                    std_h2 = GraphSequential(*gin.fc_steps[1]['module'])(q1)
                    p2 = GraphSequential(*gnn_gamma_modules[i].fc_steps[1]['module'])(q1)
                    q2 = (p2 + epsilon) * (std_h2 / (p2 + epsilon)).detach()
                    q = q2
                else:
                    std_h = GraphSequential(*modules)(h, edge_index)

                    # --- LRP-gamma ---
                    p = gnn_gamma_modules[i](h, edge_index)
                    q = (p + epsilon) * (std_h / (p + epsilon)).detach()

                # --- pick a path ---
                # Only the walk's node at this layer keeps its gradient;
                # all other nodes are detached.
                mk = torch.zeros((h.shape[0], 1), device=self.device)
                k = walk_node_indices[i + 1]
                mk[k] = 1
                ht = q * mk + q.detach() * (1 - mk)
                h = ht

            # --- FC LRP_gamma ---
            # Step 0 is the readout (needs a batch vector); later steps are
            # plain sequential FC layers.
            for i, fc_step in enumerate(fc_steps):
                modules = fc_step['module']
                std_h = nn.Sequential(*modules)(h) if i != 0 \
                    else GraphSequential(*modules)(h, torch.zeros(h.shape[0], dtype=torch.long, device=self.device))

                # --- gamma ---
                s = fc_gamma_modules[i](h) if i != 0 \
                    else fc_gamma_modules[i](h, torch.zeros(h.shape[0], dtype=torch.long, device=self.device))
                ht = (s + epsilon) * (std_h / (s + epsilon)).detach()
                h = ht

            if data_args.model_level == 'node':
                f = h[node_idx, label]
            else:
                f = h[0, label]
            # Walk relevance: gradient at the walk's starting node dotted
            # with that node's features.
            x_grads = torch.autograd.grad(outputs=f, inputs=x)[0]
            I = walk_node_indices[0]
            r = x_grads[I, :] @ x[I].T
            walk_scores.append(r)

    labels = tuple(i for i in range(data_args.num_classes))
    walk_scores_tensor_list = [None for i in labels]
    for label in labels:
        walk_scores = []

        compute_walk_score()
        walk_scores_tensor_list[label] = torch.stack(walk_scores, dim=0).view(-1, 1)

    walks = {'ids': walk_indices_list, 'score': torch.cat(walk_scores_tensor_list, dim=1)}

    # --- Debug ---
    # walk_node_indices_list = []
    # for walk_indices in walk_indices_list:
    #     walk_node_indices = [edge_index_with_loop[0, walk_indices[0]]]
    #     for walk_idx in walk_indices:
    #         walk_node_indices.append(edge_index_with_loop[1, walk_idx])
    #     walk_node_indices_list.append(torch.stack(walk_node_indices))
    # walk_node_indices_list = torch.stack(walk_node_indices_list, dim=0)
    # --- Debug end ---

    # --- Apply edge mask evaluation ---
    with torch.no_grad():
        with self.connect_mask(self):
            ex_labels = tuple(torch.tensor([label]).to(data_args.device) for label in labels)
            masks = []
            for ex_label in ex_labels:
                edge_attr = self.explain_edges_with_loop(x, walks, ex_label)
                mask = edge_attr
                mask = self.control_sparsity(mask, kwargs.get('sparsity'))
                masks.append(mask.detach())

            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

    return walks, masks, related_preds
def forward(self, x: Tensor, edge_index: Tensor, **kwargs ):
    r"""Walk-based explanation via recursive higher-order gradients.

    Extracts per-layer edge weights from the model run, then recursively
    differentiates the class score through each layer's edge-weight
    tensor, enumerating backward walks (edge sequences) and scoring each
    walk by the product of gradient contributions along it.

    Args:
        x (Tensor): Node feature matrix.
        edge_index (Tensor): Edge indices of the graph to explain.
        **kwargs: ``node_idx`` (required for node-level tasks) and
            ``sparsity`` (forwarded to :meth:`control_sparsity`).

    :rtype: (dict, list, list) — walks (``'ids'``/``'score'``), per-class
        edge masks, and related predictions for evaluation.
    """
    super().forward(x, edge_index, **kwargs)
    self.model.eval()
    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Per-layer GNN steps plus the final FC step of the recorded run.
    walk_steps, fc_step = self.extract_step(x, edge_index, detach=False)

    if data_args.model_level == 'node':
        node_idx = kwargs.get('node_idx')
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(
            node_idx, self.__num_hops__, self_loop_edge_index, relabel_nodes=True,
            num_nodes=None, flow=self.__flow__())

    def compute_walk_score(adjs, r, allow_edges, walk_idx=[]):
        # Recursively walks backward through the layers. `adjs` holds the
        # remaining layers' edge-weight tensors (output layer first), `r`
        # the current relevance, `allow_edges` the edges allowed at this
        # depth, and `walk_idx` the edge ids collected so far.
        # NOTE: the shared default `walk_idx=[]` is never mutated here
        # (only `[i] + walk_idx` copies), so the mutable-default pitfall
        # is not triggered.
        if not adjs:
            # Base case: a complete walk; record it (input-to-output order).
            walk_indices.append(walk_idx)
            walk_scores.append(r.detach())
            return
        (grads,) = torch.autograd.grad(outputs=r, inputs=adjs[0], create_graph=True)
        for i in allow_edges:
            # Rebinding `allow_edges` does not affect this loop's iteration
            # (the iterator is already bound); the new value — edges feeding
            # edge i's source node — is what the recursive call explores.
            allow_edges= torch.where(self_loop_edge_index[1] == self_loop_edge_index[0][i])[0].tolist()
            new_r = grads[i] * adjs[0][i]
            compute_walk_score(adjs[1:], new_r, allow_edges, [i] + walk_idx)

    labels = tuple(i for i in range(data_args.num_classes))
    walk_scores_tensor_list = [None for i in labels]
    for label in labels:
        if self.explain_graph:
            f = torch.unbind(fc_step['output'][0, label].unsqueeze(0))
            allow_edges = [i for i in range(self_loop_edge_index.shape[1])]
        else:
            f = torch.unbind(fc_step['output'][node_idx, label].unsqueeze(0))
            allow_edges = torch.where(self_loop_edge_index[1] == node_idx)[0].tolist()

        # Differentiate from the last layer backwards.
        adjs = [walk_step['module'][0].edge_weight for walk_step in walk_steps]
        reverse_adjs = adjs.reverse()  # in-place reverse; returns None
        walk_indices = []
        walk_scores = []

        compute_walk_score(adjs, f, allow_edges)
        walk_scores_tensor_list[label] = torch.stack(walk_scores, dim=0).view(-1, 1)

    # `walk_indices` from the last label iteration; walk sets are identical
    # across labels since enumeration does not depend on the label.
    walks = {'ids': torch.tensor(walk_indices, device=self.device), 'score': torch.cat(walk_scores_tensor_list, dim=1)}

    # --- Apply edge mask evaluation ---
    with torch.no_grad():
        with self.connect_mask(self):
            ex_labels = tuple(torch.tensor([label]).to(data_args.device) for label in labels)
            masks = []
            for ex_label in ex_labels:
                edge_attr = self.explain_edges_with_loop(x, walks, ex_label)
                mask = edge_attr
                mask = self.control_sparsity(mask, kwargs.get('sparsity'))
                masks.append(mask.detach())

            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

    return walks, masks, related_preds
def forward(self, x, edge_index, mask_features=False, **kwargs):
    r"""
    Run the explainer for a specific graph instance.

    Args:
        x (torch.Tensor): The graph instance's input node features.
        edge_index (torch.Tensor): The graph instance's edge index.
        mask_features (bool, optional): Whether to use feature mask. Not
            recommended. (Default: :obj:`False`)
        **kwargs (dict):
            :obj:`node_idx` (int): The index of node that is pending to be
            explained. (for node classification)
            :obj:`sparsity` (float): The Sparsity we need to control to
            transform a soft mask to a hard mask. (Default: :obj:`0.7`)
            :obj:`num_classes` (int): The number of task's classes.

    :rtype: (None, list, list)

    .. note::
        (None, edge_masks, related_predictions):
        edge_masks is a list of edge-level explanation for each class;
        related_predictions is a list of dictionary for each class
        where each dictionary includes 4 type predicted probabilities.
    """
    super().forward(x=x, edge_index=edge_index, **kwargs)
    self.model.eval()

    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Only operate on a k-hop subgraph around `node_idx`.
    # Get subgraph and relabel the node, mapping is the relabeled given node_idx.
    if not self.explain_graph:
        node_idx = kwargs.get('node_idx')
        self.node_idx = node_idx
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(node_idx, self.__num_hops__, self_loop_edge_index,
                                                relabel_nodes=True, num_nodes=None, flow=self.__flow__())

    # Assume the mask we will predict: one explanation per class label.
    labels = tuple(i for i in range(kwargs.get('num_classes')))
    ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)

    # Calculate mask: reset and re-initialize the learnable masks before
    # optimizing an explanation for each class.
    edge_masks = []
    for ex_label in ex_labels:
        self.__clear_masks__()
        self.__set_masks__(x, self_loop_edge_index)
        edge_masks.append(self.control_sparsity(self.gnn_explainer_alg(x, edge_index, ex_label),
                                                sparsity=kwargs.get('sparsity')))
        # edge_masks.append(self.gnn_explainer_alg(x, edge_index, ex_label))

    with torch.no_grad():
        related_preds = self.eval_related_pred(x, edge_index, edge_masks, **kwargs)

    # Leave the model clean for subsequent calls.
    self.__clear_masks__()

    return None, edge_masks, related_preds
def forward(self, x: Tensor, edge_index: Tensor, **kwargs)\
        -> Union[Tuple[None, List, List[Dict]], Tuple[Dict, List, List[Dict]]]:
    r"""
    Run the explainer for a specific graph instance (Grad-CAM over the
    last GNN layer).

    Args:
        x (torch.Tensor): The graph instance's input node features.
        edge_index (torch.Tensor): The graph instance's edge index.
        **kwargs (dict):
            :obj:`node_idx` (int): The index of node that is pending to be
            explained. (for node classification)
            :obj:`sparsity` (float): The Sparsity we need to control to
            transform a soft mask to a hard mask. (Default: :obj:`0.7`)
            :obj:`num_classes` (int): The number of task's classes.

    :rtype: (None, list, list)

    .. note::
        (None, edge_masks, related_predictions):
        edge_masks is a list of edge-level explanation for each class;
        related_predictions is a list of dictionary for each class
        where each dictionary includes 4 type predicted probabilities.
    """
    self.model.eval()
    super().forward(x, edge_index)

    labels = tuple(i for i in range(kwargs.get('num_classes')))
    ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)

    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Node-level task: restrict evaluation to the k-hop subgraph around node_idx.
    if not self.explain_graph:
        node_idx = kwargs.get('node_idx')
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(node_idx, self.__num_hops__, self_loop_edge_index,
                                                relabel_nodes=True, num_nodes=None, flow=self.__flow__())

    # --- setting GradCAM ---
    # Wrapper that exposes only the target node's output row so Grad-CAM
    # attributes w.r.t. that single prediction; `node_idx` is captured from
    # the enclosing scope (only defined in the node-level branch above).
    class model_node(nn.Module):
        def __init__(self, cls):
            super().__init__()
            self.cls = cls
            # Expose convs so GraphLayerGradCam can target the last layer.
            self.convs = cls.model.convs

        def forward(self, *args, **kwargs):
            return self.cls.model(*args, **kwargs)[node_idx].unsqueeze(0)

    if self.explain_graph:
        model = self.model
    else:
        model = model_node(self)
    self.explain_method = GraphLayerGradCam(model, model.convs[-1])
    # --- setting end ---

    masks = []
    for ex_label in ex_labels:
        attr_wo_relu = self.explain_method.attribute(x, ex_label, additional_forward_args=edge_index)
        # Node attributions -> normalized non-negative scores.
        mask = normalize(attr_wo_relu.relu())
        mask = mask.squeeze()
        # Edge score = mean of its two endpoint node scores.
        mask = (mask[self_loop_edge_index[0]] + mask[self_loop_edge_index[1]]) / 2
        mask = self.control_sparsity(mask, kwargs.get('sparsity'))
        masks.append(mask.detach())

    # Store related predictions for further evaluation.
    with torch.no_grad():
        with self.connect_mask(self):
            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

    return None, masks, related_preds
def forward(self, x: Tensor, edge_index: Tensor, **kwargs):
    r"""
    Run the explainer for a specific graph instance (DeepLIFT-style
    gradient-times-input-difference attributions).

    Args:
        x (torch.Tensor): The graph instance's input node features.
        edge_index (torch.Tensor): The graph instance's edge index.
        **kwargs (dict):
            :obj:`node_idx` (int): The index of node that is pending to be
            explained. (for node classification)
            :obj:`sparsity` (float): The Sparsity we need to control to
            transform a soft mask to a hard mask. (Default: :obj:`0.7`)
            :obj:`num_classes` (int): The number of task's classes.

    :rtype: (None, list, list)

    .. note::
        (None, edge_masks, related_predictions):
        edge_masks is a list of edge-level explanation for each class;
        related_predictions is a list of dictionary for each class
        where each dictionary includes 4 type predicted probabilities.
    """
    # --- run the model once ---
    super().forward(x=x, edge_index=edge_index, **kwargs)
    self.model.eval()
    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Node-level task: restrict evaluation to the k-hop subgraph around node_idx.
    if not self.explain_graph:
        node_idx = kwargs.get('node_idx')
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(node_idx, self.__num_hops__, self_loop_edge_index,
                                                relabel_nodes=True, num_nodes=None, flow=self.__flow__())

    # --- add shap calculation hook ---
    # NOTE(review): shap._register_hooks presumably installs the DeepLIFT
    # rescale-rule hooks on every submodule — confirm against the DeepLift impl.
    shap = DeepLift(self.model)
    self.model.apply(shap._register_hooks)

    # Joint batch: original graph followed by an all-zero reference copy;
    # reference edges are the same topology shifted by the node count.
    inp_with_ref = torch.cat(
        [x, torch.zeros(x.shape, device=self.device, dtype=torch.float)],
        dim=0).requires_grad_(True)
    edge_index_with_ref = torch.cat([edge_index, edge_index + x.shape[0]], dim=1)
    # batch vector: 0 for the original graph's nodes, 1 for the reference's.
    batch = torch.arange(2, dtype=torch.long, device=self.device).view(
        2, 1).repeat(1, x.shape[0]).reshape(-1)
    out = self.model(inp_with_ref, edge_index_with_ref, batch)

    labels = tuple(i for i in range(kwargs.get('num_classes')))
    ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)

    masks = []
    for ex_label in ex_labels:
        if self.explain_graph:
            # Graph-level: one logit per (original, reference) graph.
            f = torch.unbind(out[:, ex_label])
        else:
            # Node-level: logits of the target node in both copies.
            f = torch.unbind(out[[node_idx, node_idx + x.shape[0]], ex_label])
        # retain_graph: the same forward pass is reused for every class.
        (m, ) = torch.autograd.grad(outputs=f, inputs=inp_with_ref, retain_graph=True)
        inp, inp_ref = torch.chunk(inp_with_ref, 2)
        # DeepLIFT attribution: gradient times difference-from-reference,
        # summed over the feature dimension -> one score per node.
        attr_wo_relu = (torch.chunk(m, 2)[0] * (inp - inp_ref)).sum(1)

        # Edge score = mean of its two endpoint node scores.
        mask = attr_wo_relu.squeeze()
        mask = (mask[self_loop_edge_index[0]] + mask[self_loop_edge_index[1]]) / 2
        mask = self.control_sparsity(mask, kwargs.get('sparsity'))
        masks.append(mask.detach())

    # Store related predictions for further evaluation.
    shap._remove_hooks()
    with torch.no_grad():
        with self.connect_mask(self):
            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

    return None, masks, related_preds
def forward(self, x: Tensor, edge_index: Tensor, **kwargs ):
    r"""
    Run the explainer for a specific graph instance (walk-based scoring
    via recursive higher-order gradients through per-layer edge weights).

    Args:
        x (torch.Tensor): The graph instance's input node features.
        edge_index (torch.Tensor): The graph instance's edge index.
        **kwargs (dict):
            :obj:`node_idx` (int): The index of node that is pending to be
            explained. (for node classification)
            :obj:`sparsity` (float): The Sparsity we need to control to
            transform a soft mask to a hard mask. (Default: :obj:`0.7`)
            :obj:`num_classes` (int): The number of task's classes.

    :rtype: (dict, list, list)

    .. note::
        (walks, edge_masks, related_predictions):
        walks is a dictionary including walks' edge indices and
        corresponding explained scores;
        edge_masks is a list of edge-level explanation for each class;
        related_predictions is a list of dictionary for each class
        where each dictionary includes 4 type predicted probabilities.
    """
    super().forward(x, edge_index, **kwargs)
    self.model.eval()
    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Per-layer GNN steps plus the final FC step of the recorded run.
    walk_steps, fc_step = self.extract_step(x, edge_index, detach=False)

    if not self.explain_graph:
        node_idx = kwargs.get('node_idx')
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(
            node_idx, self.__num_hops__, self_loop_edge_index, relabel_nodes=True,
            num_nodes=None, flow=self.__flow__())

    def compute_walk_score(adjs, r, allow_edges, walk_idx=[]):
        # Recursively walks backward through the layers. `adjs` holds the
        # remaining layers' edge-weight tensors, `r` the current relevance,
        # `allow_edges` the edges allowed at this depth, and `walk_idx` the
        # edge ids collected so far.
        # NOTE: the shared default `walk_idx=[]` is never mutated here
        # (only `[i] + walk_idx` copies), so the mutable-default pitfall
        # is not triggered.
        if not adjs:
            # Base case: a complete walk; record it (input-to-output order).
            walk_indices.append(walk_idx)
            walk_scores.append(r.detach())
            return
        (grads,) = torch.autograd.grad(outputs=r, inputs=adjs[0], create_graph=True)
        for i in allow_edges:
            # Rebinding `allow_edges` does not affect this loop's iteration
            # (the iterator is already bound); the new value — edges feeding
            # edge i's source node — is what the recursive call explores.
            allow_edges= torch.where(self_loop_edge_index[1] == self_loop_edge_index[0][i])[0].tolist()
            new_r = grads[i] * adjs[0][i]
            compute_walk_score(adjs[1:], new_r, allow_edges, [i] + walk_idx)

    labels = tuple(i for i in range(kwargs.get('num_classes')))
    walk_scores_tensor_list = [None for i in labels]
    for label in labels:
        if self.explain_graph:
            f = torch.unbind(fc_step['output'][0, label].unsqueeze(0))
            allow_edges = [i for i in range(self_loop_edge_index.shape[1])]
        else:
            f = torch.unbind(fc_step['output'][node_idx, label].unsqueeze(0))
            allow_edges = torch.where(self_loop_edge_index[1] == node_idx)[0].tolist()

        # Differentiate from the last layer backwards.
        adjs = [walk_step['module'][0].edge_weight for walk_step in walk_steps]
        reverse_adjs = adjs.reverse()  # in-place reverse; returns None
        walk_indices = []
        walk_scores = []

        compute_walk_score(adjs, f, allow_edges)
        walk_scores_tensor_list[label] = torch.stack(walk_scores, dim=0).view(-1, 1)

    # `walk_indices` from the last label iteration; walk sets are identical
    # across labels since enumeration does not depend on the label.
    walks = {'ids': torch.tensor(walk_indices, device=self.device), 'score': torch.cat(walk_scores_tensor_list, dim=1)}

    # --- Apply edge mask evaluation ---
    with torch.no_grad():
        with self.connect_mask(self):
            ex_labels = tuple(torch.tensor([label]).to(self.device) for label in labels)
            masks = []
            for ex_label in ex_labels:
                edge_attr = self.explain_edges_with_loop(x, walks, ex_label)
                mask = edge_attr
                mask = self.control_sparsity(mask, kwargs.get('sparsity'))
                masks.append(mask.detach())

            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

    return walks, masks, related_preds
def forward(self, x: Tensor, edge_index: Tensor, **kwargs)\
        -> Union[Tuple[None, List, List[Dict]], Tuple[Dict, List, List[Dict]]]:
    """
    Grad-CAM explanation over the last GNN layer: given a sample, this
    function will return its predicted masks and corresponding
    predictions for evaluation.

    :param x: Tensor - Hiden features of all vertexes
    :param edge_index: Tensor - All connected edge between vertexes/nodes
    :param kwargs: ``node_idx`` (required for node-level tasks) and
        ``sparsity`` (forwarded to :meth:`control_sparsity`)
    :return: (None, per-class edge masks, related predictions)
    """
    self.model.eval()
    super().forward(x, edge_index)

    labels = tuple(i for i in range(data_args.num_classes))
    ex_labels = tuple(
        torch.tensor([label]).to(data_args.device) for label in labels)

    self_loop_edge_index, _ = add_self_loops(edge_index, num_nodes=self.num_nodes)

    # Node-level task: restrict evaluation to the k-hop subgraph around node_idx.
    if data_args.model_level == 'node':
        node_idx = kwargs.get('node_idx')
        assert node_idx is not None
        _, _, _, self.hard_edge_mask = subgraph(node_idx, self.__num_hops__, self_loop_edge_index,
                                                relabel_nodes=True, num_nodes=None, flow=self.__flow__())

    # --- setting GradCAM ---
    # Wrapper that exposes only the target node's output row so Grad-CAM
    # attributes w.r.t. that single prediction; `node_idx` is captured from
    # the enclosing scope (only defined in the node-level branch above).
    class model_node(nn.Module):
        def __init__(self, cls):
            super().__init__()
            self.cls = cls
            # Expose convs so GraphLayerGradCam can target the last layer.
            self.convs = cls.model.convs

        def forward(self, *args, **kwargs):
            return self.cls.model(*args, **kwargs)[node_idx].unsqueeze(0)

    if self.explain_graph:
        model = self.model
    else:
        model = model_node(self)
    self.explain_method = GraphLayerGradCam(model, model.convs[-1])
    # --- setting end ---

    print('#D#Mask Calculate...')
    masks = []
    for ex_label in ex_labels:
        attr_wo_relu = self.explain_method.attribute(x, ex_label, additional_forward_args=edge_index)
        # Node attributions -> normalized non-negative scores.
        mask = normalize(attr_wo_relu.relu())
        mask = mask.squeeze()
        # Edge score = mean of its two endpoint node scores.
        mask = (mask[self_loop_edge_index[0]] + mask[self_loop_edge_index[1]]) / 2
        mask = self.control_sparsity(mask, kwargs.get('sparsity'))
        masks.append(mask.detach())

    # Store related predictions for further evaluation.
    print('#D#Predict...')
    with torch.no_grad():
        with self.connect_mask(self):
            related_preds = self.eval_related_pred(x, edge_index, masks, **kwargs)

    return None, masks, related_preds