def pipeline(subgraph_max_nodes):
    """Explain node-classification predictions via MCTS subgraph search.

    For every test node with a non-zero label, searches for the subgraph
    (at most ``subgraph_max_nodes`` nodes) that best preserves the GNN's
    prediction, caches the search results, saves a visualization, and
    records fidelity (original minus masked score) and sparsity.

    Args:
        subgraph_max_nodes: maximum number of nodes allowed in the
            explanation subgraph handed to ``find_closest_node_result``.

    Returns:
        Tuple ``(fidelity_scores, sparsity_scores)`` of 1-D tensors,
        one entry per explained node.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes
    data = dataset[0]
    # Explain only test-mask nodes whose label is non-zero.
    node_indices = torch.where(data.test_mask * data.y != 0)[0]

    # Load the trained model to be explained.
    gnnNets = GnnNets_NC(input_dim, output_dim, model_args)
    checkpoint = torch.load(mcts_args.explain_model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()

    save_dir = os.path.join('./results',
                            f"{mcts_args.dataset_name}"
                            f"_{model_args.model_name}"
                            f"_{reward_args.reward_method}")
    # makedirs(exist_ok=True) also creates './results' when it is missing,
    # where the previous os.mkdir would raise FileNotFoundError.
    os.makedirs(save_dir, exist_ok=True)
    plotutils = PlotUtils(dataset_name=data_args.dataset_name)

    # The forward pass and the networkx graph do not depend on which node
    # is being explained, so compute them once instead of per iteration.
    logits, prob, _ = gnnNets(data.clone())
    _, predictions = torch.max(prob, -1)

    graph = to_networkx(data, to_undirected=True)
    node_labels = {k: int(v) for k, v in enumerate(data.y)}
    nx.set_node_attributes(graph, node_labels, 'label')

    fidelity_score_list = []
    sparsity_score_list = []
    for node_idx in tqdm(node_indices):
        result_path = os.path.join(save_dir, f"node_{node_idx}_score.pt")
        # Target class for this node: the model's own prediction.
        prediction = predictions[node_idx].item()

        # MCTS search over subgraphs within the node's receptive field
        # (num_hops = number of GNN layers).
        mcts_state_map = MCTS(node_idx=node_idx, ori_graph=graph,
                              X=data.x,
                              edge_index=data.edge_index,
                              num_hops=len(model_args.latent_dim),
                              n_rollout=mcts_args.rollout,
                              min_atoms=mcts_args.min_atoms,
                              c_puct=mcts_args.c_puct,
                              expand_atoms=mcts_args.expand_atoms)
        value_func = GnnNets_NC2value_func(gnnNets,
                                           node_idx=mcts_state_map.node_idx,
                                           target_class=prediction)
        score_func = reward_func(reward_args, value_func)
        mcts_state_map.set_score_func(score_func)

        # Reuse cached search results when a previous run saved them.
        if os.path.isfile(result_path):
            gnn_results = torch.load(result_path)
        else:
            gnn_results = mcts_state_map.mcts(verbose=True)
            torch.save(gnn_results, result_path)
        tree_node_x = find_closest_node_result(gnn_results, subgraph_max_nodes)

        # Fidelity: full-graph score minus the score with the explanation
        # coalition removed (the target node itself is always kept).
        original_node_list = [i for i in tree_node_x.ori_graph.nodes]
        masked_node_list = [i for i in tree_node_x.ori_graph.nodes
                            if i not in tree_node_x.coalition
                            or i == mcts_state_map.node_idx]
        original_score = gnn_score(original_node_list, tree_node_x.data,
                                   value_func=value_func,
                                   subgraph_building_method='zero_filling')
        masked_score = gnn_score(masked_node_list, tree_node_x.data,
                                 value_func=value_func,
                                 subgraph_building_method='zero_filling')
        sparsity_score = (1 - len(tree_node_x.coalition)
                          / tree_node_x.ori_graph.number_of_nodes())
        fidelity_score_list.append(original_score - masked_score)
        sparsity_score_list.append(sparsity_score)

        # Visualization of the chosen coalition on the original graph.
        subgraph_node_labels = nx.get_node_attributes(tree_node_x.ori_graph,
                                                      name='label')
        subgraph_node_labels = torch.tensor(
            [v for k, v in subgraph_node_labels.items()])
        plotutils.plot(tree_node_x.ori_graph, tree_node_x.coalition,
                       y=subgraph_node_labels,
                       node_idx=mcts_state_map.node_idx,
                       figname=os.path.join(save_dir, f"node_{node_idx}.png"))

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores
def pipeline(max_nodes):
    """Explain graph-classification predictions via MCTS subgraph search.

    For each selected graph, searches for the coalition of nodes (at most
    ``max_nodes``) that best preserves the GNN's predicted class, caches
    the search results, saves a visualization, and records fidelity
    (original probability minus the masked-graph score) and sparsity.

    Args:
        max_nodes: maximum number of nodes allowed in the explanation
            subgraph handed to ``find_closest_node_result``.

    Returns:
        Tuple ``(fidelity_scores, sparsity_scores)`` of 1-D tensors,
        one entry per explained graph.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    plotutils = PlotUtils(dataset_name=data_args.dataset_name)
    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes
    # mutag: explain every graph; otherwise only the test split.
    if data_args.dataset_name == 'mutag':
        data_indices = list(range(len(dataset)))
    else:
        loader = get_dataloader(dataset,
                                batch_size=train_args.batch_size,
                                random_split_flag=data_args.random_split,
                                data_split_ratio=data_args.data_split_ratio,
                                seed=data_args.seed)
        data_indices = loader['test'].dataset.indices

    # Load the trained model to be explained.
    gnnNets = GnnNets(input_dim, output_dim, model_args)
    checkpoint = torch.load(mcts_args.explain_model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()

    save_dir = os.path.join('./results',
                            f"{mcts_args.dataset_name}_"
                            f"{model_args.model_name}_"
                            f"{reward_args.reward_method}")
    # makedirs(exist_ok=True) also creates './results' when it is missing,
    # where the previous os.mkdir would raise FileNotFoundError.
    os.makedirs(save_dir, exist_ok=True)

    fidelity_score_list = []
    sparsity_score_list = []
    for i in tqdm(data_indices):
        # Model prediction for this graph (batch of one).
        data = dataset[i]
        _, probs, _ = gnnNets(Batch.from_data_list([data.clone()]))
        prediction = probs.squeeze().argmax(-1).item()
        original_score = probs.squeeze()[prediction]

        # Reward function targets the predicted class.
        value_func = GnnNets_GC2value_func(gnnNets, target_class=prediction)
        payoff_func = reward_func(reward_args, value_func)

        result_path = os.path.join(save_dir, f"example_{i}.pt")

        # MCTS search for the best-scoring coalition of nodes.
        mcts_state_map = MCTS(data.x, data.edge_index,
                              score_func=payoff_func,
                              n_rollout=mcts_args.rollout,
                              min_atoms=mcts_args.min_atoms,
                              c_puct=mcts_args.c_puct,
                              expand_atoms=mcts_args.expand_atoms)
        # Reuse cached search results when a previous run saved them.
        if os.path.isfile(result_path):
            results = torch.load(result_path)
        else:
            results = mcts_state_map.mcts(verbose=True)
            torch.save(results, result_path)

        graph_node_x = find_closest_node_result(results, max_nodes=max_nodes)
        # Nodes outside the coalition — iterate range directly instead of
        # materializing a throwaway list.
        masked_node_list = [node for node in range(graph_node_x.data.x.shape[0])
                            if node not in graph_node_x.coalition]
        fidelity_score = original_score - gnn_score(
            masked_node_list, data, value_func,
            subgraph_building_method='zero_filling')
        sparsity_score = (1 - len(graph_node_x.coalition)
                          / graph_node_x.ori_graph.number_of_nodes())
        fidelity_score_list.append(fidelity_score)
        sparsity_score_list.append(sparsity_score)

        # Visualization: text datasets carry sentence tokens in `supplement`.
        if hasattr(dataset, 'supplement'):
            words = dataset.supplement['sentence_tokens'][str(i)]
            plotutils.plot(graph_node_x.ori_graph, graph_node_x.coalition,
                           words=words,
                           figname=os.path.join(save_dir, f"example_{i}.png"))
        else:
            plotutils.plot(graph_node_x.ori_graph, graph_node_x.coalition,
                           x=graph_node_x.data.x,
                           figname=os.path.join(save_dir, f"example_{i}.png"))

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores