Beispiel #1
0
def pipeline(subgraph_max_nodes):
    """Explain node-classification predictions via MCTS subgraph search.

    For each labelled test node, search (or load a cached search result)
    for an explanatory subgraph of at most ``subgraph_max_nodes`` nodes,
    record fidelity and sparsity metrics, and save a plot of the chosen
    coalition.

    Args:
        subgraph_max_nodes: maximum number of nodes allowed in the
            explanation subgraph.

    Returns:
        Tuple of 1-D tensors ``(fidelity_scores, sparsity_scores)``,
        one entry per explained node.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes
    data = dataset[0]
    # Restrict explanations to test nodes whose label is non-zero.
    node_indices = torch.where(data.test_mask * data.y != 0)[0]

    gnnNets = GnnNets_NC(input_dim, output_dim, model_args)
    checkpoint = torch.load(mcts_args.explain_model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()
    save_dir = os.path.join('./results', f"{mcts_args.dataset_name}"
                                         f"_{model_args.model_name}"
                                         f"_{reward_args.reward_method}")
    # makedirs (vs. mkdir) also creates a missing './results' parent and
    # tolerates repeated runs.
    os.makedirs(save_dir, exist_ok=True)

    # The model's predictions over the fixed graph and its networkx view do
    # not change between nodes, so compute them once outside the loop
    # (the model is in eval mode, so the forward pass is repeatable).
    _, prob, _ = gnnNets(data.clone())
    _, predictions = torch.max(prob, -1)
    graph = to_networkx(data, to_undirected=True)
    node_labels = {k: int(v) for k, v in enumerate(data.y)}
    nx.set_node_attributes(graph, node_labels, 'label')

    plotutils = PlotUtils(dataset_name=data_args.dataset_name)
    fidelity_score_list = []
    sparsity_score_list = []
    for node_idx in tqdm(node_indices):
        # Cache file for this node's MCTS search results.
        result_path = os.path.join(save_dir, f"node_{node_idx}_score.pt")

        # Class predicted for the node being explained.
        prediction = predictions[node_idx].item()

        # MCTS search over coalitions, scored by the trained GNN.
        mcts_state_map = MCTS(node_idx=node_idx, ori_graph=graph,
                              X=data.x, edge_index=data.edge_index,
                              num_hops=len(model_args.latent_dim),
                              n_rollout=mcts_args.rollout,
                              min_atoms=mcts_args.min_atoms,
                              c_puct=mcts_args.c_puct,
                              expand_atoms=mcts_args.expand_atoms)
        value_func = GnnNets_NC2value_func(gnnNets,
                                           node_idx=mcts_state_map.node_idx,
                                           target_class=prediction)
        score_func = reward_func(reward_args, value_func)
        mcts_state_map.set_score_func(score_func)

        # Reuse cached search results when available.
        if os.path.isfile(result_path):
            gnn_results = torch.load(result_path)
        else:
            gnn_results = mcts_state_map.mcts(verbose=True)
            torch.save(gnn_results, result_path)
        tree_node_x = find_closest_node_result(gnn_results, subgraph_max_nodes)

        # Fidelity: prediction drop when the coalition (except the target
        # node itself) is masked out. Sparsity: fraction of nodes excluded
        # from the explanation.
        original_node_list = [i for i in tree_node_x.ori_graph.nodes]
        masked_node_list = [i for i in tree_node_x.ori_graph.nodes
                            if i not in tree_node_x.coalition or i == mcts_state_map.node_idx]
        original_score = gnn_score(original_node_list, tree_node_x.data,
                                   value_func=value_func, subgraph_building_method='zero_filling')
        masked_score = gnn_score(masked_node_list, tree_node_x.data,
                                 value_func=value_func, subgraph_building_method='zero_filling')
        sparsity_score = 1 - len(tree_node_x.coalition)/tree_node_x.ori_graph.number_of_nodes()

        fidelity_score_list.append(original_score - masked_score)
        sparsity_score_list.append(sparsity_score)

        # Visualization of the selected coalition, coloured by node label.
        subgraph_node_labels = nx.get_node_attributes(tree_node_x.ori_graph, name='label')
        subgraph_node_labels = torch.tensor([v for k, v in subgraph_node_labels.items()])
        plotutils.plot(tree_node_x.ori_graph, tree_node_x.coalition, y=subgraph_node_labels,
                       node_idx=mcts_state_map.node_idx,
                       figname=os.path.join(save_dir, f"node_{node_idx}.png"))

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores
Beispiel #2
0
def pipeline(max_nodes):
    """Explain graph-classification predictions via MCTS subgraph search.

    For each test graph (every graph for 'mutag'), search (or load a cached
    search result) for an explanatory subgraph of at most ``max_nodes``
    nodes, record fidelity and sparsity metrics, and save a plot of the
    chosen coalition.

    Args:
        max_nodes: maximum number of nodes allowed in the explanation
            subgraph.

    Returns:
        Tuple of 1-D tensors ``(fidelity_scores, sparsity_scores)``,
        one entry per explained graph.
    """
    dataset = get_dataset(data_args.dataset_dir, data_args.dataset_name)
    plotutils = PlotUtils(dataset_name=data_args.dataset_name)
    input_dim = dataset.num_node_features
    output_dim = dataset.num_classes

    # 'mutag' is explained in full; other datasets use the test split only.
    if data_args.dataset_name == 'mutag':
        data_indices = list(range(len(dataset)))
    else:
        loader = get_dataloader(dataset,
                                batch_size=train_args.batch_size,
                                random_split_flag=data_args.random_split,
                                data_split_ratio=data_args.data_split_ratio,
                                seed=data_args.seed)
        data_indices = loader['test'].dataset.indices

    gnnNets = GnnNets(input_dim, output_dim, model_args)
    checkpoint = torch.load(mcts_args.explain_model_path)
    gnnNets.update_state_dict(checkpoint['net'])
    gnnNets.to_device()
    gnnNets.eval()

    save_dir = os.path.join(
        './results', f"{mcts_args.dataset_name}_"
        f"{model_args.model_name}_"
        f"{reward_args.reward_method}")
    # makedirs (vs. mkdir) also creates a missing './results' parent and
    # tolerates repeated runs.
    os.makedirs(save_dir, exist_ok=True)

    fidelity_score_list = []
    sparsity_score_list = []
    for i in tqdm(data_indices):
        # Predicted class and its probability for the unmasked graph.
        data = dataset[i]
        _, probs, _ = gnnNets(Batch.from_data_list([data.clone()]))
        prediction = probs.squeeze().argmax(-1).item()
        # .item() keeps the score a plain float, so the final
        # torch.tensor(fidelity_score_list) does not collect 0-dim tensors.
        original_score = probs.squeeze()[prediction].item()

        # Reward function targeting the predicted class.
        value_func = GnnNets_GC2value_func(gnnNets, target_class=prediction)
        payoff_func = reward_func(reward_args, value_func)

        # Cache file for this example's MCTS search results.
        result_path = os.path.join(save_dir, f"example_{i}.pt")

        # MCTS over node coalitions, scored by the payoff function.
        mcts_state_map = MCTS(data.x,
                              data.edge_index,
                              score_func=payoff_func,
                              n_rollout=mcts_args.rollout,
                              min_atoms=mcts_args.min_atoms,
                              c_puct=mcts_args.c_puct,
                              expand_atoms=mcts_args.expand_atoms)

        # Reuse cached search results when available.
        if os.path.isfile(result_path):
            results = torch.load(result_path)
        else:
            results = mcts_state_map.mcts(verbose=True)
            torch.save(results, result_path)

        # Fidelity: prediction drop when the coalition is masked out.
        # Sparsity: fraction of nodes excluded from the explanation.
        graph_node_x = find_closest_node_result(results, max_nodes=max_nodes)
        masked_node_list = [
            node for node in list(range(graph_node_x.data.x.shape[0]))
            if node not in graph_node_x.coalition
        ]
        fidelity_score = original_score - gnn_score(
            masked_node_list,
            data,
            value_func,
            subgraph_building_method='zero_filling')
        sparsity_score = 1 - len(
            graph_node_x.coalition) / graph_node_x.ori_graph.number_of_nodes()
        fidelity_score_list.append(fidelity_score)
        sparsity_score_list.append(sparsity_score)

        # Visualization: text datasets carry a token 'supplement' so nodes
        # can be labelled with words; otherwise plot with node features.
        if hasattr(dataset, 'supplement'):
            words = dataset.supplement['sentence_tokens'][str(i)]
            plotutils.plot(graph_node_x.ori_graph,
                           graph_node_x.coalition,
                           words=words,
                           figname=os.path.join(save_dir, f"example_{i}.png"))
        else:
            plotutils.plot(graph_node_x.ori_graph,
                           graph_node_x.coalition,
                           x=graph_node_x.data.x,
                           figname=os.path.join(save_dir, f"example_{i}.png"))

    fidelity_scores = torch.tensor(fidelity_score_list)
    sparsity_scores = torch.tensor(sparsity_score_list)
    return fidelity_scores, sparsity_scores