Ejemplo n.º 1
0
def main():
    """Load a saved GCN node classifier and explain a single node prediction.

    Parses the program arguments, prints diagnostics about the loaded
    checkpoint (optimizer, dictionary keys, labels, dimensions), rebuilds a
    ``GcnEncoderNode`` from the stored ``model_state`` and builds an
    ``Explainer`` around it. When ``prog_args.explain_node`` is set, the
    explainer is run on that node and the returned masked adjacency, edges
    and features are printed; otherwise a hint is printed instead.
    """
    separator = '------------------------------------------------------------------------------------'

    # Defaults for every program parameter unless the user overrides them.
    prog_args = parse_explainer_args.arg_parse()

    # This script runs without a summary writer.
    prog_args.writer = None

    log_root = prog_args.logdir
    run_prefix = io_utils.gen_prefix(prog_args)
    path = os.path.join(log_root, run_prefix)
    print("Tensorboard writer path :\n", path)
    print("No. of epochs :", prog_args.num_epochs)

    # Only report the selected device; nothing is moved to it here.
    if not prog_args.gpu:
        print('\n------------------------------------------')
        print("Using CPU")
    else:
        print('\nCUDA_VISIBLE_DEVICES')
        print('------------------------------------------')
        print("CUDA", prog_args.cuda)

    # Restore the previously saved checkpoint dictionary.
    checkpoint = io_utils.load_ckpt(prog_args)
    optimizer = checkpoint['optimizer']

    print("Model optimizer :", optimizer)
    param_groups = optimizer.state_dict()['param_groups']
    print("Model optimizer state dictionary :\n", param_groups)

    print(separator)
    print("Keys in loaded model dictionary :", list(checkpoint))
    optimizer_keys = list(optimizer.state_dict())
    print("Keys in loaded model optimizer dictionary:", optimizer_keys)
    print("All loaded labels :\n", checkpoint['cg']['label'])

    print()
    print(f'mask_act:{prog_args.mask_act}, mask_bias:{prog_args.mask_bias}, '
          f'explainer_suffix:{prog_args.explainer_suffix}')

    # Graph mode is on if requested directly or implied by a graph selector.
    graph_mode = (prog_args.graph_mode or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # The computation-graph dictionary stored alongside the model.
    comp_graph = checkpoint['cg']
    input_dim = comp_graph['feat'].shape[2]
    num_classes = comp_graph['pred'].shape[2]
    print(f'\nLoaded model from subdirectory "{prog_args.ckptdir}" ...')
    print("input dim :", input_dim, "; num classes :", num_classes)
    print("Labels of retrieved data :\n", comp_graph['label'])

    print(separator)
    summary = [
        ("Multigraph class :", prog_args.multigraph_class),
        ("Graph Index :", prog_args.graph_idx),
        ("Explainer graph mode :", graph_mode),
        ("Input dimension :", input_dim),
        ("Hidden dimension :", prog_args.hidden_dim),
        ("Output dimension :", prog_args.output_dim),
        ("Number of classes :", num_classes),
        ("Number of GCN layers :", prog_args.num_gc_layers),
        ("Batch Normalization :", prog_args.bn),
    ]
    for caption, value in summary:
        print(caption, value)

    model = models.GcnEncoderNode(
        input_dim=input_dim,
        hidden_dim=prog_args.hidden_dim,
        embedding_dim=prog_args.output_dim,
        label_dim=num_classes,
        num_layers=prog_args.num_gc_layers,
        bn=prog_args.bn,
        args=prog_args,
    )

    print("\nGcnEncoderNode model :\n", model)

    # Restore the weights saved in the checkpoint and show the load report.
    load_report = model.load_state_dict(checkpoint['model_state'])
    print("Model checked result :", load_report)
    print(separator + '\n')

    print('Explaining single default node :', prog_args.explain_node)

    # The explainer is trained for far fewer epochs than the GCN itself: it
    # only fits the k-hop neighbourhood implied by the number of GCN layers.
    print(f'GNN Explainer is trained based on {prog_args.num_epochs} epochs.')
    print("Writer :", prog_args.writer)

    explainer = explain.Explainer(
        model=model,
        adj=comp_graph["adj"],
        feat=comp_graph["feat"],
        label=comp_graph["label"],
        pred=comp_graph["pred"],
        train_idx=comp_graph["train_idx"],
        args=prog_args,
        writer=prog_args.writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    if prog_args.explain_node is None:
        print("Please provide node for explanation.")
        return

    # Masked adjacency, edges and features of the explained subgraph.
    explanation = explainer.explain(prog_args.explain_node,
                                    unconstrained=False)
    masked_adj, masked_edges, masked_features = explanation

    print("Returned masked adjacency matrix :\n", masked_adj)
    print("Returned masked edges matrix :\n", masked_edges)
    print("Returned masked features matrix :\n", masked_features)
Ejemplo n.º 2
0
def main():
    """Load a trained GCN checkpoint and run the explainer in the chosen mode.

    The mode is selected from the parsed program arguments:

    * ``explain_node`` set: explain that single node.
    * graph mode with ``multigraph_class >= 0``: explain graphs whose label
      equals that class.
    * graph mode with ``graph_idx == -1``: explain a fixed set of graphs.
    * graph mode otherwise: explain the graph at ``graph_idx``.
    * node mode with ``multinode_class >= 0``: explain nodes of that class.
    * node mode otherwise: explain a fixed range of nodes and gather stats.

    Exits with status 1 if the user declines to remove an existing log
    directory when ``clean_log`` is set.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir,
                            io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            answer = input(
                "Are you sure you want to remove this directory? (y/n): ")
            if answer.lower().strip()[:1] != "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode
    graph_mode = (prog_args.graph_mode or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # Build the model matching the explainer mode
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # Class weight in CE loss for handling imbalanced label classes.
            # Only move it to CUDA when a GPU run was requested; calling
            # .cuda() unconditionally fails on CPU-only runs.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            if prog_args.gpu:
                loss_weight = loss_weight.cuda()
            prog_args.loss_weight = loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()
    # load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Create explainer
    explainer = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    # TODO: API should definitely be cleaner
    # Let's define exactly which modes we support
    # We could even move each mode to a different method (even file)
    if prog_args.explain_node is not None:
        explainer.explain(prog_args.explain_node, unconstrained=False)
    elif graph_mode:
        if prog_args.multigraph_class >= 0:
            print(cg_dict["label"])
            # only run for graphs with label specified by multigraph_class
            labels = cg_dict["label"].numpy()
            graph_indices = []
            for i, label in enumerate(labels):
                if label == prog_args.multigraph_class:
                    graph_indices.append(i)
                # Stop once more than 30 matching graphs were collected.
                if len(graph_indices) > 30:
                    break
            print(
                "Graph indices for label ",
                prog_args.multigraph_class,
                " : ",
                graph_indices,
            )
            explainer.explain_graphs(graph_indices=graph_indices)

        elif prog_args.graph_idx == -1:
            # just run for a customized set of indices
            explainer.explain_graphs(graph_indices=[1, 2, 3, 4])
        else:
            explainer.explain(
                node_idx=0,
                graph_idx=prog_args.graph_idx,
                graph_mode=True,
                unconstrained=False,
            )
            # NOTE(review): writer is None when prog_args.writer is falsy;
            # confirm plot_cmap_tb tolerates that.
            io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap")
    else:
        if prog_args.multinode_class >= 0:
            print(cg_dict["label"])
            # only run for nodes with label specified by multinode_class
            labels = cg_dict["label"][0]  # already numpy matrix

            node_indices = []
            for i, label in enumerate(labels):
                # Stop once more than 4 matching nodes were collected.
                if len(node_indices) > 4:
                    break
                if label == prog_args.multinode_class:
                    node_indices.append(i)
            print(
                "Node indices for label ",
                prog_args.multinode_class,
                " : ",
                node_indices,
            )
            explainer.explain_nodes(node_indices, prog_args)

        else:
            # explain a fixed set of nodes; the returned stats are not used
            explainer.explain_nodes_gnn_stats(range(400, 700, 5), prog_args)
Ejemplo n.º 3
0
def main():
    """Time GNN Explainer on a small, dataset-specific set of test nodes.

    Loads the checkpoint named by the program arguments, rebuilds the
    matching GCN model, runs one forward pass, and then times
    ``explain_nodes_gnn_stats`` over a hard-coded list of node indices
    chosen per synthetic dataset.

    Raises:
        ValueError: if ``prog_args.dataset`` has no test nodes defined here.

    Exits with status 1 if the user declines to remove an existing log
    directory when ``clean_log`` is set.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir,
                            io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            answer = input(
                "Are you sure you want to remove this directory? (y/n): ")
            if answer.lower().strip()[:1] != "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load data and a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode (node classif)
    graph_mode = (prog_args.graph_mode or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # Build the model matching the explainer mode
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # Class weight in CE loss for handling imbalanced label classes.
            # Only move it to CUDA when a GPU run was requested; calling
            # .cuda() unconditionally fails on CPU-only runs.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            if prog_args.gpu:
                loss_weight = loss_weight.cuda()
            prog_args.loss_weight = loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()

    # Load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Conversion of the data required to get correct model output for GraphSHAP
    adj = torch.tensor(cg_dict["adj"], dtype=torch.float)
    x = torch.tensor(cg_dict["feat"], requires_grad=True, dtype=torch.float)
    # NOTE(review): the outputs below are not used later in this function;
    # the calls are kept in case they have side effects - confirm.
    if prog_args.gpu:
        y_pred, att_adj = model(x.cuda(), adj.cuda())
    else:
        y_pred, att_adj = model(x, adj)

    # Transform their data into our format
    data = transform_data(adj, x, cg_dict["label"][0].tolist())

    # Test nodes: only these specific nodes are used, as they are the ones
    # added manually as part of the defined shapes.
    test_nodes_by_dataset = {
        'syn1': list(range(400, 410, 5)),
        'syn2': list(range(400, 405, 5)) + list(range(1100, 1105, 5)),
        'syn4': list(range(511, 523, 6)),
        'syn5': list(range(511, 529, 9)),
    }
    try:
        node_indices = test_nodes_by_dataset[prog_args.dataset]
    except KeyError:
        # Fail here with a clear message instead of an unbound-variable
        # error further down.
        raise ValueError(
            "No test nodes defined for dataset {!r}; expected one of {}".
            format(prog_args.dataset, sorted(test_nodes_by_dataset))) from None

    # Run GNN Explainer and retrieve produced explanations
    gnne = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    ### GNNE
    # Explain a set of nodes - accuracy on edges this time
    t = time.time()
    gnne_edge_accuracy, gnne_auc, gnne_node_accuracy, important_nodes_gnne =\
        gnne.explain_nodes_gnn_stats(
            node_indices, prog_args
        )
    e = time.time()
    print('Time: ', e - t)
Ejemplo n.º 4
0
def main():
    """Compare GraphSHAP and GNN Explainer on synthetic-dataset test nodes.

    Loads the checkpoint named by the program arguments, rebuilds the
    matching GCN model, then for each hard-coded test node runs GraphSHAP
    and counts how many of its top-ranked neighbours fall in the ground-truth
    index range following the node. Afterwards GNN Explainer is run on the
    same nodes and both sets of accuracy figures are printed.

    Raises:
        ValueError: if ``prog_args.dataset`` has no test nodes defined here.

    Exits with status 1 if the user declines to remove an existing log
    directory when ``clean_log`` is set.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir,
                            io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            answer = input(
                "Are you sure you want to remove this directory? (y/n): ")
            if answer.lower().strip()[:1] != "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load data and a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode (node classif)
    graph_mode = (prog_args.graph_mode or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # Build the model matching the explainer mode
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # Class weight in CE loss for handling imbalanced label classes.
            # Only move it to CUDA when a GPU run was requested; calling
            # .cuda() unconditionally fails on CPU-only runs.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            if prog_args.gpu:
                loss_weight = loss_weight.cuda()
            prog_args.loss_weight = loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()

    # Load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Conversion of the data required to get correct model output for GraphSHAP
    adj = torch.tensor(cg_dict["adj"], dtype=torch.float)
    x = torch.tensor(cg_dict["feat"], requires_grad=True, dtype=torch.float)
    # NOTE(review): the outputs below are not used later in this function;
    # the call is kept in case the forward pass has side effects - confirm.
    if prog_args.gpu:
        y_pred, att_adj = model(x.cuda(), adj.cuda())
    else:
        y_pred, att_adj = model(x, adj)

    # Transform their data into our format
    data = transform_data(adj, x, cg_dict["label"][0].tolist())

    # Test nodes: only these specific nodes are used, as they are the ones
    # added manually as part of the defined shapes.
    k = 4  # number of nodes for the shape introduced (house, cycle)
    K = 0
    if prog_args.dataset == 'syn1':
        node_indices = list(range(400, 450, 5))
    elif prog_args.dataset == 'syn2':
        node_indices = list(range(400, 425, 5)) + list(range(1100, 1125, 5))
    elif prog_args.dataset == 'syn4':
        node_indices = list(range(511, 571, 6))
        if prog_args.hops == 3:
            k = 5
        else:
            K = 5
    elif prog_args.dataset == 'syn5':
        node_indices = list(range(511, 601, 9))
        if prog_args.hops == 3:
            k = 8
        else:
            k = 5
            K = 8
    else:
        # Fail here with a clear message instead of an unbound-variable
        # error further down.
        raise ValueError(
            "No test nodes defined for dataset {!r}; expected one of "
            "'syn1', 'syn2', 'syn4', 'syn5'".format(prog_args.dataset))

    # GraphSHAP explainer
    graphshap = GraphSHAP(data, model, adj, writer, prog_args.dataset,
                          prog_args.gpu)

    # Run GNN Explainer and retrieve produced explanations
    gnne = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    # GraphSHAP - assess accuracy of explanations
    # Loop over test nodes
    accuracy = []
    feat_accuracy = []
    for node_idx in node_indices:
        start = time.time()
        graphshap_explanations = graphshap.explain(
            [node_idx],
            prog_args.hops,
            prog_args.num_samples,
            prog_args.info,
            prog_args.multiclass,
            prog_args.fullempty,
            prog_args.S,
            prog_args.hv,
            prog_args.feat,
            prog_args.coal,
            prog_args.g,
            prog_args.regu,
        )[0]

        end = time.time()
        print('GS Time:', end - start)

        # Keep only node explanations: entries from index graphshap.F onwards
        graphshap_node_explanations = graphshap_explanations[graphshap.F:]

        # Derive ground truth from graph structure: the indices that
        # immediately follow the test node.
        ground_truth = list(range(node_idx + 1, node_idx + max(k, K) + 1))

        # Retrieve the top k + 1 element indices from the node explanations
        if graphshap.neighbours.shape[0] > k:
            _, indices = torch.topk(
                torch.tensor(graphshap_node_explanations.T), k + 1)
            # could weight importance based on the top-k values
            hits = sum(node.item() in ground_truth
                       for node in graphshap.neighbours[indices])
            # Sort of accuracy metric
            accuracy.append(hits / k)

            print('There are {} from targeted shape among most imp. nodes'.
                  format(hits))

        # Look at importance distribution among features
        # Identify most important features and check if it corresponds to truly imp ones
        if prog_args.dataset == 'syn2':
            graphshap_feat_explanations = graphshap_explanations[:graphshap.F]
            print('Feature importance graphshap',
                  graphshap_feat_explanations.T)
            # Score 1 when feature 0 ranks as the most important, else 0.
            if np.argsort(graphshap_feat_explanations)[-1] == 0:
                feat_accuracy.append(1)
            else:
                feat_accuracy.append(0)

    # Metric for graphshap. No node may have had more than k neighbours, in
    # which case nothing was scored: report NaN rather than divide by zero.
    if accuracy:
        final_accuracy = sum(accuracy) / len(accuracy)
    else:
        final_accuracy = float('nan')

    ### GNNE
    # Explain a set of nodes - accuracy on edges this time
    _, gnne_edge_accuracy, gnne_auc, gnne_node_accuracy =\
        gnne.explain_nodes_gnn_stats(
            node_indices, prog_args
        )

    ### GRAD benchmark
    # Disabled for now (it would call explain_nodes_gnn_stats with
    # model="grad"); zeros are reported as placeholders.
    grad_edge_accuracy = 0
    grad_node_accuracy = 0

    ### GAT
    # Nothing for now - implement a GAT on the side and look at weight coefs

    ### Results
    print(
        'Accuracy for GraphSHAP is {:.2f} vs {:.2f},{:.2f} for GNNE vs {:.2f},{:.2f} for GRAD'
        .format(final_accuracy, np.mean(gnne_edge_accuracy),
                np.mean(gnne_node_accuracy), np.mean(grad_edge_accuracy),
                np.mean(grad_node_accuracy)))
    if prog_args.dataset == 'syn2':
        print('Most important feature was found in {:.2f}% of the case'.format(
            100 * np.mean(feat_accuracy)))

    print('GNNE_auc is:', gnne_auc)