def explain(
    self, node_idx, graph_idx=0, graph_mode=False, unconstrained=False, model="exp"
):
    """Explain a single node (or whole-graph) prediction.

    Args:
        node_idx: Index of the query node in graph ``graph_idx``. In graph
            mode it is used unchanged as the query index.
        graph_idx: Index of the graph in ``self.adj`` / ``self.feat`` /
            ``self.label``.
        graph_mode: If True, use the full adjacency of graph ``graph_idx``
            instead of extracting the node's neighbourhood with
            ``self.extract_neighborhood``.
        unconstrained: Forwarded to the ``ExplainModule`` forward call.
        model: Explanation method. ``"grad"`` uses the symmetrised absolute
            adjacency gradient; ``"exp"`` optimises the explainer for
            ``self.args.num_epochs`` epochs and uses its masked adjacency;
            any other value runs a single epoch and uses the sigmoid of the
            attention weights returned by the explainer.

    Returns:
        The masked adjacency as a numpy array, multiplied element-wise by the
        sub-graph adjacency. It is also saved as a ``.npy`` file in
        ``self.args.logdir``.
    """
    # Index of the query node in the new (sub-graph) adjacency.
    if graph_mode:
        node_idx_new = node_idx
        sub_adj = self.adj[graph_idx]
        sub_feat = self.feat[graph_idx, :]
        sub_label = self.label[graph_idx]
        # self.adj is indexed by graph first (see self.adj[graph_idx] above),
        # so the node count of one graph is the size of axis 1, not axis 0.
        neighbors = np.asarray(range(self.adj.shape[1]))
    else:
        print("node label: ", self.label[graph_idx][node_idx])
        node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood(
            node_idx, graph_idx
        )
        print("neigh graph idx: ", node_idx, node_idx_new)
        sub_label = np.expand_dims(sub_label, axis=0)

    # Add a leading batch dimension of size 1.
    sub_adj = np.expand_dims(sub_adj, axis=0)
    sub_feat = np.expand_dims(sub_feat, axis=0)

    adj = torch.tensor(sub_adj, dtype=torch.float)
    x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
    label = torch.tensor(sub_label, dtype=torch.long)

    if self.graph_mode:
        pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
        print("Graph predicted label: ", pred_label)
    else:
        pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
        print("Node predicted label: ", pred_label[node_idx_new])

    explainer = ExplainModule(
        adj=adj,
        x=x,
        model=self.model,
        label=label,
        args=self.args,
        writer=self.writer,
        graph_idx=self.graph_idx,
        graph_mode=self.graph_mode,
    )
    if self.args.gpu:
        explainer = explainer.cuda()

    self.model.eval()

    if model == "grad":
        # Gradient baseline: |d/d adj|, made symmetric and squashed to (0, 1).
        explainer.zero_grad()
        adj_grad = torch.abs(
            explainer.adj_feat_grad(node_idx_new, pred_label[node_idx_new])[0]
        )[graph_idx]
        masked_adj = adj_grad + adj_grad.t()
        masked_adj = torch.sigmoid(masked_adj)
        masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze()
    else:
        explainer.train()
        begin_time = time.time()
        for epoch in range(self.args.num_epochs):
            explainer.zero_grad()
            explainer.optimizer.zero_grad()
            ypred, adj_atts = explainer(node_idx_new, unconstrained=unconstrained)
            loss = explainer.loss(ypred, pred_label, node_idx_new, epoch)
            loss.backward()

            explainer.optimizer.step()
            if explainer.scheduler is not None:
                explainer.scheduler.step()

            mask_density = explainer.mask_density()
            if self.print_training:
                print(
                    "epoch: ",
                    epoch,
                    "; loss: ",
                    loss.item(),
                    "; mask density: ",
                    mask_density.item(),
                    "; pred: ",
                    ypred,
                )
            single_subgraph_label = sub_label.squeeze()

            if self.writer is not None:
                self.writer.add_scalar("mask/density", mask_density, epoch)
                self.writer.add_scalar(
                    "optimization/lr",
                    explainer.optimizer.param_groups[0]["lr"],
                    epoch,
                )
                if epoch % 25 == 0:
                    explainer.log_mask(epoch)
                    explainer.log_masked_adj(
                        node_idx_new, epoch, label=single_subgraph_label
                    )
                    explainer.log_adj_grad(
                        node_idx_new, pred_label, epoch, label=single_subgraph_label
                    )

                if epoch == 0 and self.model.att:
                    # Explain node with the model's own attention weights.
                    print("adj att size: ", adj_atts.size())
                    adj_att = torch.sum(adj_atts[0], dim=2)
                    # Move adj to wherever the attention lives instead of
                    # hard-coding .cuda(), which crashed CPU-only runs.
                    node_adj_att = adj_att * adj.float().to(adj_att.device)
                    io_utils.log_matrix(
                        self.writer, node_adj_att[0], "att/matrix", epoch
                    )
                    node_adj_att = node_adj_att[0].cpu().detach().numpy()
                    G = io_utils.denoise_graph(
                        node_adj_att,
                        node_idx_new,
                        threshold=3.8,
                        max_component=True,
                    )
                    io_utils.log_graph(
                        self.writer,
                        G,
                        name="att/graph",
                        identify_self=not self.graph_mode,
                        nodecolor="label",
                        edge_vmax=None,
                        args=self.args,
                    )
            # Only the "exp" method is iteratively optimised; other methods
            # need a single forward pass.
            if model != "exp":
                break

        print("finished training in ", time.time() - begin_time)
        if model == "exp":
            masked_adj = (
                explainer.masked_adj[0].cpu().detach().numpy() * sub_adj.squeeze()
            )
        else:
            adj_atts = torch.sigmoid(adj_atts).squeeze()
            masked_adj = adj_atts.cpu().detach().numpy() * sub_adj.squeeze()

    fname = 'masked_adj_' + io_utils.gen_explainer_prefix(self.args) + (
        'node_idx_' + str(node_idx) + 'graph_idx_' + str(self.graph_idx) + '.npy')
    # The log directory is not guaranteed to exist (e.g. when no writer is used).
    os.makedirs(self.args.logdir, exist_ok=True)
    with open(os.path.join(self.args.logdir, fname), 'wb') as outfile:
        np.save(outfile, np.asarray(masked_adj.copy()))
        print("Saved adjacency matrix to ", fname)
    return masked_adj
def main():
    """Run GNNExplainer on a fixed, dataset-specific set of nodes and time it.

    Parses the command-line arguments and optionally sets up a TensorBoard
    writer. It then loads the checkpoint, rebuilds the GCN (graph or node
    variant) and calls ``explain_nodes_gnn_stats`` on the predefined test
    nodes of the selected synthetic dataset. Finally it prints the elapsed
    time.

    Raises:
        ValueError: If ``prog_args.dataset`` has no predefined test nodes.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            if not input(
                "Are you sure you want to remove this directory? (y/n): "
            ).lower().strip()[:1] == "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load data and a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode (graph vs. node classification)
    graph_mode = (prog_args.graph_mode
                  or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # build model
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # class weight in CE loss for handling imbalanced label classes;
            # only move it to the GPU when one is actually in use.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            prog_args.loss_weight = loss_weight.cuda() if prog_args.gpu else loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()
    # Load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Conversion of the data required to get the correct model output for
    # GraphSHAP.
    # NOTE(review): y_pred, att_adj, data, k and K are only consumed by the
    # (currently disabled) GraphSHAP path.
    adj = torch.tensor(cg_dict["adj"], dtype=torch.float)
    x = torch.tensor(cg_dict["feat"], requires_grad=True, dtype=torch.float)
    if prog_args.gpu:
        y_pred, att_adj = model(x.cuda(), adj.cuda())
    else:
        y_pred, att_adj = model(x, adj)

    # Transform their data into our format
    data = transform_data(adj, x, cg_dict["label"][0].tolist())

    # Generate test nodes: only nodes that were added manually as part of the
    # defined shapes are used.
    k = 4  # number of nodes for the shape introduced (house, cycle)
    K = 0
    if prog_args.dataset == 'syn1':
        node_indices = list(range(400, 410, 5))
    elif prog_args.dataset == 'syn2':
        node_indices = list(range(400, 405, 5)) + list(range(1100, 1105, 5))
    elif prog_args.dataset == 'syn4':
        node_indices = list(range(511, 523, 6))
        if prog_args.hops == 3:
            k = 5
        else:
            K = 5
    elif prog_args.dataset == 'syn5':
        node_indices = list(range(511, 529, 9))
        if prog_args.hops == 3:
            k = 7
            K = 8
        else:
            k = 5
            K = 8
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "No predefined test nodes for dataset {!r}; expected one of "
            "'syn1', 'syn2', 'syn4', 'syn5'".format(prog_args.dataset)
        )

    # Run GNN Explainer and retrieve produced explanations
    gnne = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    ### GNNE
    # Explain a set of nodes - accuracy on edges this time
    t = time.time()
    gnne_edge_accuracy, gnne_auc, gnne_node_accuracy, important_nodes_gnne = \
        gnne.explain_nodes_gnn_stats(node_indices, prog_args)
    e = time.time()
    print('Time: ', e - t)
def main():
    """Compare GraphSHAP node-importance accuracy with GNNExplainer on synthetic data.

    Loads the checkpoint selected by the command-line arguments and rebuilds
    the GCN. For each predefined test node of the chosen dataset it runs
    GraphSHAP and counts how many of the top-ranked neighbours lie in the
    ground-truth range ``node_idx + 1 .. node_idx + max(k, K)``. It then runs
    GNNExplainer's ``explain_nodes_gnn_stats`` on the same nodes and prints
    both scores. For ``syn2`` it also reports how often feature 0 is ranked
    as the most important feature.

    Raises:
        ValueError: If ``prog_args.dataset`` has no predefined test nodes.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            if not input(
                "Are you sure you want to remove this directory? (y/n): "
            ).lower().strip()[:1] == "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load data and a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode (graph vs. node classification)
    graph_mode = (prog_args.graph_mode
                  or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # build model
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # class weight in CE loss for handling imbalanced label classes;
            # only move it to the GPU when one is actually in use.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            prog_args.loss_weight = loss_weight.cuda() if prog_args.gpu else loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()
    # Load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Conversion of the data required to get the correct model output for
    # GraphSHAP.
    adj = torch.tensor(cg_dict["adj"], dtype=torch.float)
    x = torch.tensor(cg_dict["feat"], requires_grad=True, dtype=torch.float)
    if prog_args.gpu:
        y_pred, att_adj = model(x.cuda(), adj.cuda())
    else:
        y_pred, att_adj = model(x, adj)

    # Transform their data into our format
    data = transform_data(adj, x, cg_dict["label"][0].tolist())

    # Generate test nodes: only nodes that were added manually as part of the
    # defined shapes are used.
    k = 4  # number of nodes for the shape introduced (house, cycle)
    K = 0
    if prog_args.dataset == 'syn1':
        node_indices = list(range(400, 450, 5))
    elif prog_args.dataset == 'syn2':
        node_indices = list(range(400, 425, 5)) + list(range(1100, 1125, 5))
    elif prog_args.dataset == 'syn4':
        node_indices = list(range(511, 571, 6))
        if prog_args.hops == 3:
            k = 5
        else:
            K = 5
    elif prog_args.dataset == 'syn5':
        node_indices = list(range(511, 601, 9))
        if prog_args.hops == 3:
            k = 8
        else:
            k = 5
        K = 8
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "No predefined test nodes for dataset {!r}; expected one of "
            "'syn1', 'syn2', 'syn4', 'syn5'".format(prog_args.dataset)
        )

    # GraphSHAP explainer
    graphshap = GraphSHAP(data, model, adj, writer, prog_args.dataset, prog_args.gpu)

    # Run GNN Explainer and retrieve produced explanations
    gnne = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    # GraphSHAP - assess accuracy of explanations over the test nodes
    accuracy = []
    feat_accuracy = []
    for node_idx in node_indices:
        start = time.time()
        graphshap_explanations = graphshap.explain(
            [node_idx],
            prog_args.hops,
            prog_args.num_samples,
            prog_args.info,
            prog_args.multiclass,
            prog_args.fullempty,
            prog_args.S,
            prog_args.hv,
            prog_args.feat,
            prog_args.coal,
            prog_args.g,
            prog_args.regu,
        )[0]
        end = time.time()
        print('GS Time:', end - start)

        # Keep only node explanations (the first graphshap.F entries are features)
        graphshap_node_explanations = graphshap_explanations[graphshap.F:]

        # Derive ground truth from graph structure
        ground_truth = set(range(node_idx + 1, node_idx + max(k, K) + 1))

        # Retrieve top elements indices from graphshap_node_explanations
        if graphshap.neighbours.shape[0] > k:
            val, indices = torch.topk(
                torch.tensor(graphshap_node_explanations.T), k + 1)
            # could weight importance based on val
            i = sum(
                node.item() in ground_truth
                for node in graphshap.neighbours[indices]
            )
            # Sort of accuracy metric
            accuracy.append(i / k)
            print('There are {} from targeted shape among most imp. nodes'.
                  format(i))

        # Look at importance distribution among features: check whether the
        # most important feature is the truly important one (feature 0)
        if prog_args.dataset == 'syn2':
            graphshap_feat_explanations = graphshap_explanations[:graphshap.F]
            print('Feature importance graphshap', graphshap_feat_explanations.T)
            if np.argsort(graphshap_feat_explanations)[-1] == 0:
                feat_accuracy.append(1)
            else:
                feat_accuracy.append(0)

    # Metric for graphshap; NaN when no node had enough neighbours to be scored
    # (previously a ZeroDivisionError).
    final_accuracy = sum(accuracy) / len(accuracy) if accuracy else float('nan')

    ### GNNE
    # Explain a set of nodes - accuracy on edges this time
    _, gnne_edge_accuracy, gnne_auc, gnne_node_accuracy = \
        gnne.explain_nodes_gnn_stats(node_indices, prog_args)

    ### GRAD benchmark
    # Disabled for now; would be computed with
    # gnne.explain_nodes_gnn_stats(node_indices, prog_args, model="grad").
    grad_edge_accuracy = 0
    grad_node_accuracy = 0

    ### GAT
    # Nothing for now - implem a GAT on the side and look at weights coef

    ### Results
    print(
        'Accuracy for GraphSHAP is {:.2f} vs {:.2f},{:.2f} for GNNE vs {:.2f},{:.2f} for GRAD'
        .format(final_accuracy, np.mean(gnne_edge_accuracy), np.mean(gnne_node_accuracy),
                np.mean(grad_edge_accuracy), np.mean(grad_node_accuracy)))
    if prog_args.dataset == 'syn2':
        print('Most important feature was found in {:.2f}% of the case'.format(
            100 * np.mean(feat_accuracy)))

    print('GNNE_auc is:', gnne_auc)
def main():
    """Command-line entry point: build the model and run the requested explanation.

    Depending on the parsed arguments this:

    * explains a single node (``--explain_node``);
    * in graph mode, explains graphs of a given class (``multigraph_class``),
      a fixed set of graphs (``graph_idx == -1``) or one graph;
    * otherwise explains nodes of a given class (``multinode_class``) or a
      fixed range of nodes with ``explain_nodes_gnn_stats``.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir, io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            if not input(
                "Are you sure you want to remove this directory? (y/n): "
            ).lower().strip()[:1] == "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode
    graph_mode = (prog_args.graph_mode
                  or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # build model
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # class weight in CE loss for handling imbalanced label classes;
            # only move it to the GPU when one is actually in use.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            prog_args.loss_weight = loss_weight.cuda() if prog_args.gpu else loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()
    # load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Create explainer
    explainer = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    # TODO: API should definitely be cleaner
    # Let's define exactly which modes we support
    # We could even move each mode to a different method (even file)
    if prog_args.explain_node is not None:
        explainer.explain(prog_args.explain_node, unconstrained=False)
    elif graph_mode:
        if prog_args.multigraph_class >= 0:
            print(cg_dict["label"])
            # only run for graphs with label specified by multigraph_class
            labels = cg_dict["label"].numpy()
            graph_indices = []
            for i, l in enumerate(labels):
                if l == prog_args.multigraph_class:
                    graph_indices.append(i)
                # Cap the number of explained graphs (stops at 31).
                if len(graph_indices) > 30:
                    break
            print(
                "Graph indices for label ",
                prog_args.multigraph_class,
                " : ",
                graph_indices,
            )
            explainer.explain_graphs(graph_indices=graph_indices)

        elif prog_args.graph_idx == -1:
            # just run for a customized set of indices
            explainer.explain_graphs(graph_indices=[1, 2, 3, 4])
        else:
            explainer.explain(
                node_idx=0,
                graph_idx=prog_args.graph_idx,
                graph_mode=True,
                unconstrained=False,
            )
            io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap")
    else:
        if prog_args.multinode_class >= 0:
            print(cg_dict["label"])
            # only run for nodes with label specified by multinode_class
            labels = cg_dict["label"][0]  # already numpy matrix

            node_indices = []
            for i, l in enumerate(labels):
                # Cap the number of explained nodes (stops at 5).
                if len(node_indices) > 4:
                    break
                if l == prog_args.multinode_class:
                    node_indices.append(i)
            print(
                "Node indices for label ",
                prog_args.multinode_class,
                " : ",
                node_indices,
            )
            explainer.explain_nodes(node_indices, prog_args)

        else:
            # explain a set of nodes
            masked_adj = explainer.explain_nodes_gnn_stats(
                range(400, 700, 5), prog_args)
def explain(self,
            node_idx,
            graph_idx=0,
            graph_mode=False,
            unconstrained=False,
            exp_model="exp"):
    """Explain a single node prediction with verbose diagnostics and a plot.

    Trains an ``ExplainModule`` for ``self.args.num_epochs`` epochs, saves
    the resulting masked adjacency as a ``.npy`` file in ``self.args.logdir``
    and plots the explanation sub-graph via ``self.PlotSubGraph``.

    Args:
        node_idx: Index of the query node in graph ``graph_idx``.
        graph_idx: Index of the graph in ``self.adj`` / ``self.feat`` /
            ``self.label``.
        graph_mode: If True, use the whole graph ``graph_idx`` instead of the
            neighbourhood extracted by ``self.extract_neighborhood``.
        unconstrained: Forwarded to the ``ExplainModule`` forward call.
        exp_model: Explanation method; only ``"exp"`` is supported here.

    Returns:
        Tuple ``(masked_adj, masked_edges, masked_features)`` of numpy arrays:
        the learned masked adjacency multiplied by the sub-graph adjacency,
        the raw edge mask and the feature mask.

    Raises:
        ValueError: If ``exp_model`` is not ``"exp"``. Previously such calls
            ran one epoch and produced no usable explanation.
    """
    if exp_model != "exp":
        raise ValueError(
            "Unsupported exp_model {!r}; only 'exp' is supported".format(exp_model))

    print('************** Explaining node : {} **************'.format(
        node_idx))
    print('The label for graph index {} and node index {} : {}'.format(
        graph_idx, node_idx, self.label[graph_idx][node_idx]))
    print("Labels of all the nodes :\n", self.label)
    print("Shape of retrieved neighborhoods :", self.neighborhoods.shape)
    print("No. of neighborhoods :",
          len(self.neighborhoods[graph_idx][node_idx]))
    print('List of neighborhoods for explaining node {} :'.format(node_idx))
    print(self.neighborhoods[graph_idx][node_idx])

    # Index of the query node in the new (sub-graph) adjacency.
    if graph_mode:
        node_idx_new = node_idx
        sub_adj = self.adj[graph_idx]
        sub_feat = self.feat[graph_idx, :]
        sub_label = self.label[graph_idx]
        # self.adj is indexed by graph first (see self.adj[graph_idx] above),
        # so the node count of one graph is the size of axis 1, not axis 0.
        neighbors = np.asarray(range(self.adj.shape[1]))
    else:
        print("Ground truth, node label :", self.label[graph_idx][node_idx])
        # Computational graph: sub-graph adjacency, features, labels and the
        # indices of the neighbouring nodes in the full graph.
        node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood(
            node_idx, graph_idx)
        sub_label = np.expand_dims(sub_label, axis=0)

    # Add a leading batch dimension of size 1.
    sub_adj = np.expand_dims(sub_adj, axis=0)
    sub_feat = np.expand_dims(sub_feat, axis=0)

    print("Neighbouring graph index for node " + str(node_idx) +
          " with new node index " + str(node_idx_new))
    print("Expand dimension of Subgraph label :\n", sub_label)
    print("Subgraph neighbors :\n", neighbors)

    tensor_adj = torch.tensor(sub_adj, dtype=torch.float)
    tensor_x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
    tensor_label = torch.tensor(sub_label, dtype=torch.long)

    if self.graph_mode:
        pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
        print("Graph predicted label: ", pred_label)
    else:
        pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
        print("Neighbours of predicted node labels :",
              self.pred[graph_idx][neighbors])
        print(
            'Predicted labels for all {} neighbours (includes itself) :\n{}'
            .format(len(pred_label), pred_label))
        print('Predicted label for node {} : {}'.format(
            node_idx, pred_label[node_idx_new]))

    # The explainer works on the tensor versions of the sub-graph data.
    explainerMod = ExplainModule(
        adj=tensor_adj,
        x=tensor_x,
        model=self.model,
        label=tensor_label,
        args=self.args,
        writer=self.writer,
        graph_idx=self.graph_idx,
        graph_mode=self.graph_mode,
    )
    # Keep the explainer on the same device as the model when running on GPU.
    if self.args.gpu:
        explainerMod = explainerMod.cuda()

    self.model.eval()
    explainerMod.train()
    begin_time = time.time()
    for epoch in range(self.args.num_epochs):
        explainerMod.zero_grad()
        explainerMod.optimizer.zero_grad()
        # node_idx_new is the query node's index inside the sub-graph.
        ypred, adj_atts = explainerMod(node_idx_new,
                                       unconstrained=unconstrained)
        loss = explainerMod.loss(ypred, pred_label, node_idx_new, epoch)
        loss.backward()

        explainerMod.optimizer.step()
        # Honour a configured LR schedule.
        if explainerMod.scheduler is not None:
            explainerMod.scheduler.step()

        mask_density = explainerMod.mask_density()
        print("epoch: ", epoch, "; loss: ", loss.item(), "; mask density: ",
              mask_density.item(), "; pred: ", ypred)
        print(
            "------------------------------------------------------------------"
        )

    print("\n--------------------------------------------")
    print("Final ypred after training : ", ypred)
    print("pred_label : ", pred_label)
    print("node_idx_new : ", node_idx_new)
    print("Completed training in ", time.time() - begin_time)

    masked_adj = (explainerMod.masked_adj[0].cpu().detach().numpy() *
                  sub_adj.squeeze())
    masked_edges = explainerMod.mask.cpu().detach().numpy()
    masked_features = explainerMod.feat_mask.cpu().detach().numpy()

    ypred_detach = ypred.cpu().detach().numpy()
    ypred_node = np.argmax(ypred_detach, axis=0)
    print('Detach ypred : {} and Argmax node : {}'.format(
        ypred_detach, ypred_node))

    # Trained masked adjacency, edge mask and feature mask
    print("Shape of masked adjacency matrix : ", masked_adj.shape)
    print("The masked adjacency matrix at index [0] :\n", masked_adj[0])
    print("Shape of masked edges matrix : ", masked_edges.shape)
    print("The masked edges adjacency matrix at index [0] :\n",
          masked_edges[0])
    print("Shape of masked features matrix : ", masked_features.shape)
    print("The masked features adjacency matrix at index [0] :\n",
          masked_features[0])

    fname = 'masked_adj_' + io_utils.gen_explainer_prefix(
        self.args) + ('_node_idx_' + str(node_idx) + '_graph_idx_' +
                      str(self.graph_idx) + '.npy')
    # The log directory is not guaranteed to exist.
    os.makedirs(self.args.logdir, exist_ok=True)
    with open(os.path.join(self.args.logdir, fname), 'wb') as outfile:
        np.save(outfile, np.asarray(masked_adj.copy()))
        print("Saved adjacency matrix to \"" + fname + "\".")

    self.PlotSubGraph(masked_adj,
                      masked_edges,
                      node_idx_new,
                      node_idx,
                      feats=sub_feat.squeeze(),
                      labels=tensor_label.cpu().detach().numpy().squeeze(),
                      threshold_num=12,
                      adj_mode=True)

    return masked_adj, masked_edges, masked_features