示例#1
0
def eval_Mutagenicity(data, model, args):
    """Evaluate GraphSVX on the MUTAG (Mutagenicity) dataset.

    For the graphs at positions 100-119 of ``range(len(data.selected))``,
    explain the graph-level prediction with GraphSVX, take the top-k most
    important nodes (k = number of ground-truth nodes) and record the
    fraction of them that belong to the ground-truth motif.

    Graphs without any ground-truth edge are skipped, since the metric
    ``hits / k`` is undefined for ``k == 0``. Graphs with fewer explained
    nodes than ground-truth nodes are also skipped.

    Args:
        data (NameSpace): pre-processed MUTAG dataset; reads ``selected``,
            ``edge_label_lists`` and ``edge_lists``.
        model: trained GNN model.
        args (argparse.Namespace): all parameters (``gpu``, ``hops``,
            ``num_samples``, ``info``, ``multiclass``, ``fullempty``, ``S``,
            ``feat``, ``coal``, ``g``).

    Results are printed; nothing is returned.
    """
    allgraphs = list(range(len(data.selected)))[100:120]
    accuracy = []
    for graph_idx in allgraphs:
        graphsvx = GraphSVX(data, model, args.gpu)
        graphsvx_explanations = graphsvx.explain_graphs([graph_idx],
                                                        args.hops,
                                                        args.num_samples,
                                                        args.info,
                                                        args.multiclass,
                                                        args.fullempty,
                                                        args.S,
                                                        'graph_classification',
                                                        args.feat,
                                                        args.coal,
                                                        args.g,
                                                        regu=0,
                                                        vizu=False)[0]

        # Find ground truth in original data: edges with a non-zero label are
        # motif edges; the union of their endpoints is the node ground truth.
        idexs = np.nonzero(data.edge_label_lists[graph_idx])[0].tolist()
        inter = [data.edge_lists[graph_idx][i] for i in idexs]
        ground_truth = list({node for edge in inter for node in edge})
        gt_nodes = set(ground_truth)  # O(1) membership tests below

        k = len(ground_truth)  # Length gt
        if k == 0:
            # No motif in this graph: the accuracy hits / k is undefined.
            continue

        # Retrieve top k elements indices from graphsvx_explanations
        if len(graphsvx.neighbours) >= k:
            _, indices = torch.topk(torch.tensor(
                graphsvx_explanations.T), k)
            # could weight importance based on the topk values
            hits = sum(1 for node in torch.tensor(graphsvx.neighbours)[indices]
                       if node.item() in gt_nodes)
            # Sort of accuracy metric: share of top-k nodes in the motif
            acc = hits / k
            accuracy.append(acc)
            print('acc:', acc)
            print('indexes', indices)
            print('gt', ground_truth)

    print('Accuracy', accuracy)
    # Avoid numpy's "Mean of empty slice" warning when every graph was skipped.
    print('Mean accuracy', np.mean(accuracy) if accuracy else float('nan'))
示例#2
0
def eval_syn6(data, model, args):
    """Explain and evaluate the syn6 dataset.

    Explains up to the first 100 graphs whose label in ``data.y`` is
    non-zero. For each graph, record the fraction of the top-k most important
    nodes that lie in the fixed ground-truth motif (nodes 20-24).

    Args:
        data (NameSpace): dataset information; reads ``y``.
        model: trained GNN model.
        args (NameSpace): input arguments (``gpu``, ``hops``,
            ``num_samples``, ``info``, ``multiclass``, ``fullempty``, ``S``,
            ``feat``, ``coal``, ``g``).

    Results are printed; nothing is returned.
    """
    # Define graphs used for evaluation.
    # NOTE(review): `.T[0]` on the result of np.nonzero only works if data.y is
    # a torch tensor (numpy returns a tuple here) — confirm against the loader.
    allgraphs = np.nonzero(data.y).T[0].tolist()[:100]

    # The ground-truth motif is identical for every graph: nodes 20..24.
    ground_truth = list(range(20, 25))
    gt_nodes = set(ground_truth)
    k = len(ground_truth)  # Length gt

    accuracy = []
    for graph_idx in allgraphs:
        graphsvx = GraphSVX(data, model, args.gpu)
        graphsvx_explanations = graphsvx.explain_graphs([graph_idx],
                                                        args.hops,
                                                        args.num_samples,
                                                        args.info,
                                                        args.multiclass,
                                                        args.fullempty,
                                                        args.S,
                                                        'graph_classification',
                                                        args.feat,
                                                        args.coal,
                                                        args.g,
                                                        regu=0,
                                                        vizu=False)[0]

        # Retrieve top k elements indices from graphsvx_explanations
        if len(graphsvx.neighbours) >= k:
            _, indices = torch.topk(torch.tensor(
                graphsvx_explanations.T), k)
            # could weight importance based on the topk values
            hits = sum(1 for node in torch.tensor(graphsvx.neighbours)[indices]
                       if node.item() in gt_nodes)
            # Sort of accuracy metric: share of top-k nodes in the motif
            acc = hits / k
            accuracy.append(acc)
            print('acc:', acc)

    print('accuracy', accuracy)
    # Avoid numpy's "Mean of empty slice" warning when nothing was evaluated.
    print('mean', np.mean(accuracy) if accuracy else float('nan'))
示例#3
0
def eval_syn(data, model, args):
    """Evaluate the explainer on synthetic node-classification datasets
    that have a ground truth.

    For each test node, the ground truth is the ``max(k, K)`` node ids that
    directly follow it. The node metric is the number of ground-truth nodes
    among the ``k + 1`` most important nodes, divided by ``k``. For syn2, the
    feature metric is the overlap between the two most important features
    and the truly informative features ``{0, 1}``.

    Args:
        data (NameSpace): dataset information.
        model: trained GNN model.
        args (NameSpace): input arguments; ``args.dataset`` must be one of
            'syn1', 'syn2', 'syn4', 'syn5'.

    Raises:
        ValueError: if ``args.dataset`` is not a supported synthetic dataset.

    Results are printed; nothing is returned.
    """
    # Define ground truth and test nodes for each dataset
    k = 4  # number of nodes for the shape introduced (house, cycle)
    K = 0
    if args.dataset == 'syn1':
        node_indices = list(range(400, 500, 5))
    elif args.dataset == 'syn2':
        node_indices = list(range(400, 425, 5)) + list(range(1100, 1125, 5))
    elif args.dataset == 'syn4':
        node_indices = list(range(511, 691, 6))  # (511, 571, 6)
        if args.hops == 3:
            k = 5
        else:
            K = 5
    elif args.dataset == 'syn5':
        node_indices = list(range(511, 654, 9))  # (511, 601, 9)
        k = 5
        K = 7
    else:
        # Previously this fell through to an UnboundLocalError on node_indices.
        raise ValueError(
            "eval_syn does not support dataset {!r}; expected one of "
            "'syn1', 'syn2', 'syn4', 'syn5'".format(args.dataset))

    # GraphSVX - assess accuracy of explanations
    graphsvx = GraphSVX(data, model, args.gpu)

    # Loop over test nodes
    accuracy = []
    feat_accuracy = []
    for node_idx in node_indices:
        graphsvx_explanations = graphsvx.explain([node_idx],
                                                 args.hops,
                                                 args.num_samples,
                                                 args.info,
                                                 args.multiclass,
                                                 args.fullempty,
                                                 args.S,
                                                 args.hv,
                                                 args.feat,
                                                 args.coal,
                                                 args.g,
                                                 args.regu,
                                                 )[0]

        # The first graphsvx.F entries are feature importances; the rest are
        # node importances aligned with graphsvx.neighbours.
        graphsvx_node_explanations = graphsvx_explanations[graphsvx.F:]

        # Derive ground truth from graph structure
        ground_truth = list(range(node_idx + 1, node_idx + max(k, K) + 1))
        gt_nodes = set(ground_truth)

        if args.info:
            neighbours = list(graphsvx.neighbours)
            # Only look up the position when it exists: .index would raise.
            if ground_truth[0] in neighbours:
                l = neighbours.index(ground_truth[0])
                # l indexes neighbours, so slice the node explanations (the
                # full vector is offset by the F feature entries).
                # NOTE(review): window of 5 nodes is hard-coded, not len(ground_truth).
                print('Importance:', np.sum(graphsvx_node_explanations[l:l+5]))

        # Retrieve top k+1 elements indices from graphsvx_node_explanations
        if graphsvx.neighbours.shape[0] > k:
            _, indices = torch.topk(torch.tensor(
                graphsvx_node_explanations.T), k+1)
            # could weight importance based on the topk values
            hits = sum(1 for node in graphsvx.neighbours[indices]
                       if node.item() in gt_nodes)
            # Sort of accuracy metric.
            # NOTE(review): when K > k (syn4 with hops != 3, syn5) up to k+1
            # hits are possible, so this ratio can exceed 1.
            accuracy.append(hits / k)

            if args.info:
                print('There are {} from targeted shape among most imp. nodes'.format(hits))

        # Look at importance distribution among features
        # Identify most important features and check if it corresponds to truly imp ones
        if args.dataset == 'syn2':
            graphsvx_feat_explanations = graphsvx_explanations[:graphsvx.F]
            print('Feature importance graphsvx',
                  graphsvx_feat_explanations.T)
            top2 = set(np.argsort(graphsvx_feat_explanations)[-2:])
            feat_accuracy.append(len(top2.intersection([0, 1])) / 2)

    # feat_accuracy is empty for every dataset except syn2; report nan without
    # triggering numpy's "Mean of empty slice" warning.
    node_acc = np.mean(accuracy) if accuracy else float('nan')
    feat_acc = np.mean(feat_accuracy) if feat_accuracy else float('nan')
    print('Node Accuracy: {:.2f}, Feature Accuracy: {:.2f}'.format(node_acc,
                                                                   feat_acc))