Example #1
0
    def explain(
        self, node_idx, graph_idx=0, graph_mode=False, unconstrained=False, model="exp"
    ):
        """Explain a single node prediction (or a whole graph in graph mode).

        Args:
            node_idx: Index of the node to explain. In graph mode it is used
                directly as the node index inside graph ``graph_idx``.
            graph_idx: Index of the graph inside ``self.adj`` / ``self.feat`` /
                ``self.label`` / ``self.pred``.
            graph_mode: If True, use the whole graph ``graph_idx`` instead of
                the neighbourhood extracted around ``node_idx``.
            unconstrained: Forwarded to the ``ExplainModule`` forward call.
            model: Explanation method. ``"exp"`` optimises the explainer for
                ``self.args.num_epochs`` epochs and uses its masked adjacency;
                ``"grad"`` returns a gradient baseline; any other value runs a
                single optimisation step and builds the mask from the sigmoid
                of the attention output ``adj_atts``.

        Returns:
            numpy.ndarray: The mask multiplied element-wise by the (sub)graph
            adjacency. It is also written as an ``.npy`` file under
            ``self.args.logdir``.
        """
        # index of the query node in the new adj
        if graph_mode:
            node_idx_new = node_idx
            sub_adj = self.adj[graph_idx]
            sub_feat = self.feat[graph_idx, :]
            sub_label = self.label[graph_idx]
            # NOTE(review): shape[0] of self.adj indexes graphs (see
            # self.adj[graph_idx] above), not nodes; `neighbors` is only read
            # below when self.graph_mode is False — confirm intent.
            neighbors = np.asarray(range(self.adj.shape[0]))
        else:
            print("node label: ", self.label[graph_idx][node_idx])
            node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood(
                node_idx, graph_idx
            )
            print("neigh graph idx: ", node_idx, node_idx_new)
            sub_label = np.expand_dims(sub_label, axis=0)

        # Add a leading batch axis of size 1.
        sub_adj = np.expand_dims(sub_adj, axis=0)
        sub_feat = np.expand_dims(sub_feat, axis=0)

        adj = torch.tensor(sub_adj, dtype=torch.float)
        x = torch.tensor(sub_feat, requires_grad=True, dtype=torch.float)
        label = torch.tensor(sub_label, dtype=torch.long)

        # NOTE(review): this reads self.graph_mode while the subgraph
        # selection above reads the `graph_mode` argument.
        if self.graph_mode:
            pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
            print("Graph predicted label: ", pred_label)
        else:
            pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
            print("Node predicted label: ", pred_label[node_idx_new])

        explainer = ExplainModule(
            adj=adj,
            x=x,
            model=self.model,
            label=label,
            args=self.args,
            writer=self.writer,
            graph_idx=self.graph_idx,
            graph_mode=self.graph_mode,
        )
        if self.args.gpu:
            explainer = explainer.cuda()

        self.model.eval()

        if model == "grad":
            # Gradient baseline: |adjacency gradient| from adj_feat_grad,
            # symmetrised, squashed by a sigmoid, restricted to existing edges.
            explainer.zero_grad()
            adj_grad = torch.abs(
                explainer.adj_feat_grad(node_idx_new, pred_label[node_idx_new])[0]
            )[graph_idx]
            masked_adj = adj_grad + adj_grad.t()
            masked_adj = torch.sigmoid(masked_adj)
            masked_adj = masked_adj.cpu().detach().numpy() * sub_adj.squeeze()
        else:
            explainer.train()
            begin_time = time.time()
            # Loop-invariant label array passed to the logging helpers.
            single_subgraph_label = sub_label.squeeze()
            for epoch in range(self.args.num_epochs):
                explainer.zero_grad()
                explainer.optimizer.zero_grad()
                ypred, adj_atts = explainer(node_idx_new, unconstrained=unconstrained)
                loss = explainer.loss(ypred, pred_label, node_idx_new, epoch)
                loss.backward()

                explainer.optimizer.step()
                if explainer.scheduler is not None:
                    explainer.scheduler.step()

                mask_density = explainer.mask_density()
                if self.print_training:
                    print(
                        "epoch: ",
                        epoch,
                        "; loss: ",
                        loss.item(),
                        "; mask density: ",
                        mask_density.item(),
                        "; pred: ",
                        ypred,
                    )

                if self.writer is not None:
                    self.writer.add_scalar("mask/density", mask_density, epoch)
                    self.writer.add_scalar(
                        "optimization/lr",
                        explainer.optimizer.param_groups[0]["lr"],
                        epoch,
                    )
                    # Heavier mask/adjacency logging only every 25 epochs.
                    if epoch % 25 == 0:
                        explainer.log_mask(epoch)
                        explainer.log_masked_adj(
                            node_idx_new, epoch, label=single_subgraph_label
                        )
                        explainer.log_adj_grad(
                            node_idx_new, pred_label, epoch, label=single_subgraph_label
                        )

                    # Attention models: log the attention once, at epoch 0.
                    if epoch == 0 and self.model.att:
                        print("adj att size: ", adj_atts.size())
                        adj_att = torch.sum(adj_atts[0], dim=2)
                        # Put the adjacency on the attention's device instead
                        # of calling .cuda() unconditionally (which crashed
                        # CPU-only runs).
                        node_adj_att = adj_att * adj.float().to(adj_att.device)
                        io_utils.log_matrix(
                            self.writer, node_adj_att[0], "att/matrix", epoch
                        )
                        node_adj_att = node_adj_att[0].cpu().detach().numpy()
                        G = io_utils.denoise_graph(
                            node_adj_att,
                            node_idx_new,
                            threshold=3.8,
                            max_component=True,
                        )
                        io_utils.log_graph(
                            self.writer,
                            G,
                            name="att/graph",
                            identify_self=not self.graph_mode,
                            nodecolor="label",
                            edge_vmax=None,
                            args=self.args,
                        )
                # Non-"exp" methods only need a single forward/backward pass.
                if model != "exp":
                    break

            print("finished training in ", time.time() - begin_time)
            if model == "exp":
                masked_adj = (
                    explainer.masked_adj[0].cpu().detach().numpy() * sub_adj.squeeze()
                )
            else:
                adj_atts = torch.sigmoid(adj_atts).squeeze()
                masked_adj = adj_atts.cpu().detach().numpy() * sub_adj.squeeze()

        # NOTE: the name has no separator before "node_idx_"/"graph_idx_"; kept
        # byte-identical in case other code relies on this exact file name.
        fname = 'masked_adj_' + io_utils.gen_explainer_prefix(self.args) + (
                'node_idx_'+str(node_idx)+'graph_idx_'+str(self.graph_idx)+'.npy')
        with open(os.path.join(self.args.logdir, fname), 'wb') as outfile:
            np.save(outfile, np.asarray(masked_adj.copy()))
            print("Saved adjacency matrix to ", fname)
        return masked_adj
Example #2
0
def main():
    """Run GNNExplainer on a fixed set of synthetic-dataset test nodes.

    The function runs these steps in order:

    * parse the command-line arguments;
    * optionally (re)create the summary-writer log directory;
    * rebuild the GCN from the checkpoint;
    * pick the dataset-specific test nodes;
    * call ``explain_nodes_gnn_stats`` on them, printing the elapsed time.

    Raises:
        ValueError: if no test nodes are defined for ``prog_args.dataset``.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir,
                            io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            # Anything other than an answer starting with "y" aborts.
            if not input(
                    "Are you sure you want to remove this directory? (y/n): "
            ).lower().strip()[:1] == "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load data and a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode (graph vs node classification)
    graph_mode = (prog_args.graph_mode or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # build model
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # Class weights in the CE loss for imbalanced label classes. Only
            # move them to the GPU when one is used; an unconditional
            # .cuda() crashed CPU-only runs.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            if prog_args.gpu:
                loss_weight = loss_weight.cuda()
            prog_args.loss_weight = loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()

    # Load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Conversion of the data required to get the model output for GraphSHAP.
    # NOTE(review): y_pred, att_adj, data, k and K are unused while the
    # GraphSHAP part below stays disabled.
    adj = torch.tensor(cg_dict["adj"], dtype=torch.float)
    x = torch.tensor(cg_dict["feat"], requires_grad=True, dtype=torch.float)
    if prog_args.gpu:
        y_pred, att_adj = model(x.cuda(), adj.cuda())
    else:
        y_pred, att_adj = model(x, adj)

    # Transform their data into our format
    data = transform_data(adj, x, cg_dict["label"][0].tolist())

    # Test nodes: only nodes that were added manually as part of the
    # defined shapes.
    # node_indices = extract_test_nodes(data, num_samples=10, cg_dict['train_idx'])
    k = 4  # number of nodes for the shape introduced (house, cycle)
    K = 0
    if prog_args.dataset == 'syn1':
        node_indices = list(range(400, 410, 5))
    elif prog_args.dataset == 'syn2':
        node_indices = list(range(400, 405, 5)) + list(range(1100, 1105, 5))
    elif prog_args.dataset == 'syn4':
        node_indices = list(range(511, 523, 6))
        if prog_args.hops == 3:
            k = 5
        else:
            K = 5
    elif prog_args.dataset == 'syn5':
        node_indices = list(range(511, 529, 9))
        if prog_args.hops == 3:
            k = 7
            K = 8
        else:
            k = 5
            K = 8
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "No test nodes defined for dataset '{}'".format(prog_args.dataset))

    # GraphSHAP explainer (disabled)
    # graphshap = GraphSHAP(data, model, adj, writer, prog_args.dataset, prog_args.gpu)

    # Run GNN Explainer and retrieve produced explanations
    gnne = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    ### GNNE
    # Explain a set of nodes - accuracy on edges this time
    t = time.time()
    gnne_edge_accuracy, gnne_auc, gnne_node_accuracy, important_nodes_gnne =\
        gnne.explain_nodes_gnn_stats(
            node_indices, prog_args
        )
    e = time.time()
    print('Time: ', e - t)
Example #3
0
def main():
    """Compare GraphSHAP and GNNExplainer explanations on synthetic datasets.

    For each dataset-specific test node, count how many of GraphSHAP's
    top-ranked neighbours belong to the shape (house, cycle, ...) nodes that
    are assumed to follow the test node's index. GNNExplainer statistics are
    then computed on the same nodes, and both results are printed.

    Raises:
        ValueError: if no test nodes are defined for ``prog_args.dataset``.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir,
                            io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            # Anything other than an answer starting with "y" aborts.
            if not input(
                    "Are you sure you want to remove this directory? (y/n): "
            ).lower().strip()[:1] == "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load data and a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode (graph vs node classification)
    graph_mode = (prog_args.graph_mode or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # build model
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # Class weights in the CE loss for imbalanced label classes. Only
            # move them to the GPU when one is used; an unconditional
            # .cuda() crashed CPU-only runs.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            if prog_args.gpu:
                loss_weight = loss_weight.cuda()
            prog_args.loss_weight = loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()

    # Load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Conversion of the data required to get the model output for GraphSHAP
    adj = torch.tensor(cg_dict["adj"], dtype=torch.float)
    x = torch.tensor(cg_dict["feat"], requires_grad=True, dtype=torch.float)
    if prog_args.gpu:
        y_pred, att_adj = model(x.cuda(), adj.cuda())
    else:
        y_pred, att_adj = model(x, adj)

    # Transform their data into our format
    data = transform_data(adj, x, cg_dict["label"][0].tolist())

    # Test nodes: only nodes that were added manually as part of the
    # defined shapes.
    # node_indices = extract_test_nodes(data, num_samples=10, cg_dict['train_idx'])
    k = 4  # number of nodes for the shape introduced (house, cycle)
    K = 0
    if prog_args.dataset == 'syn1':
        node_indices = list(range(400, 450, 5))
    elif prog_args.dataset == 'syn2':
        node_indices = list(range(400, 425, 5)) + list(range(1100, 1125, 5))
    elif prog_args.dataset == 'syn4':
        node_indices = list(range(511, 571, 6))
        if prog_args.hops == 3:
            k = 5
        else:
            K = 5
    elif prog_args.dataset == 'syn5':
        node_indices = list(range(511, 601, 9))
        if prog_args.hops == 3:
            k = 8
        else:
            k = 5
            K = 8
    else:
        # Previously this fell through and failed later with a NameError.
        raise ValueError(
            "No test nodes defined for dataset '{}'".format(prog_args.dataset))

    # GraphSHAP explainer
    graphshap = GraphSHAP(data, model, adj, writer, prog_args.dataset,
                          prog_args.gpu)

    # Run GNN Explainer and retrieve produced explanations
    gnne = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    # GraphSHAP - assess accuracy of explanations over the test nodes
    accuracy = []
    feat_accuracy = []
    for node_idx in node_indices:
        start = time.time()
        graphshap_explanations = graphshap.explain(
            [node_idx],
            prog_args.hops,
            prog_args.num_samples,
            prog_args.info,
            prog_args.multiclass,
            prog_args.fullempty,
            prog_args.S,
            prog_args.hv,
            prog_args.feat,
            prog_args.coal,
            prog_args.g,
            prog_args.regu,
        )[0]

        end = time.time()
        print('GS Time:', end - start)

        # Predicted class (currently unused)
        pred_val, predicted_class = y_pred[0, node_idx, :].max(dim=0)

        # The first graphshap.F entries are feature importances; the rest
        # are node importances.
        graphshap_node_explanations = graphshap_explanations[graphshap.F:]

        # Ground truth derived from the graph structure: the shape nodes are
        # assumed to be the max(k, K) indices following node_idx.
        ground_truth = set(range(node_idx + 1, node_idx + max(k, K) + 1))

        # Retrieve the top k + 1 indices from graphshap_node_explanations
        if graphshap.neighbours.shape[0] > k:
            _, indices = torch.topk(
                torch.tensor(graphshap_node_explanations.T), k + 1)
            # could weight importance based on the top-k values
            hits = sum(1 for node in graphshap.neighbours[indices]
                       if node.item() in ground_truth)
            # Rough accuracy metric: shape nodes found among the top k + 1.
            accuracy.append(hits / k)

            print('There are {} from targeted shape among most imp. nodes'.
                  format(hits))

        # Look at the importance distribution among features and check
        # whether the most important one is the truly important feature 0.
        if prog_args.dataset == 'syn2':
            graphshap_feat_explanations = graphshap_explanations[:graphshap.F]
            print('Feature importance graphshap',
                  graphshap_feat_explanations.T)
            if np.argsort(graphshap_feat_explanations)[-1] == 0:
                feat_accuracy.append(1)
            else:
                feat_accuracy.append(0)

    # Metric for graphshap. Guard against no node qualifying (previously a
    # ZeroDivisionError); report NaN in that case.
    if accuracy:
        final_accuracy = sum(accuracy) / len(accuracy)
    else:
        final_accuracy = float("nan")

    ### GNNE
    # Explain a set of nodes - accuracy on edges this time
    _, gnne_edge_accuracy, gnne_auc, gnne_node_accuracy =\
        gnne.explain_nodes_gnn_stats(
            node_indices, prog_args
        )

    ### GRAD benchmark (disabled for now); would be:
    # _, grad_edge_accuracy, grad_auc, grad_node_accuracy = \
    #     gnne.explain_nodes_gnn_stats(node_indices, prog_args, model="grad")
    grad_edge_accuracy = 0
    grad_node_accuracy = 0

    ### GAT
    # Nothing for now - implement a GAT on the side and look at weight coefs

    ### Results
    print(
        'Accuracy for GraphSHAP is {:.2f} vs {:.2f},{:.2f} for GNNE vs {:.2f},{:.2f} for GRAD'
        .format(final_accuracy, np.mean(gnne_edge_accuracy),
                np.mean(gnne_node_accuracy), np.mean(grad_edge_accuracy),
                np.mean(grad_node_accuracy)))
    if prog_args.dataset == 'syn2':
        print('Most important feature was found in {:.2f}% of the case'.format(
            100 * np.mean(feat_accuracy)))

    print('GNNE_auc is:', gnne_auc)
Example #4
0
def main():
    """Explain a node, a graph, or a set of nodes/graphs with GNNExplainer.

    The mode is picked from the command-line arguments:

    * ``explain_node`` set: explain that single node;
    * graph mode with ``multigraph_class >= 0``: explain up to 31 graphs whose
      label equals ``multigraph_class``;
    * graph mode with ``graph_idx == -1``: explain graphs 1, 2, 3 and 4;
    * any other graph mode: explain graph ``graph_idx``;
    * ``multinode_class >= 0``: explain up to 5 nodes with that label;
    * otherwise: run ``explain_nodes_gnn_stats`` on nodes 400, 405, ..., 695.
    """
    # Load a configuration
    prog_args = arg_parse()

    if prog_args.gpu:
        os.environ["CUDA_VISIBLE_DEVICES"] = prog_args.cuda
        print("CUDA", prog_args.cuda)
    else:
        print("Using CPU")

    # Configure the logging directory
    if prog_args.writer:
        path = os.path.join(prog_args.logdir,
                            io_utils.gen_explainer_prefix(prog_args))
        if os.path.isdir(path) and prog_args.clean_log:
            print('Removing existing log dir: ', path)
            # Anything other than an answer starting with "y" aborts.
            if not input(
                    "Are you sure you want to remove this directory? (y/n): "
            ).lower().strip()[:1] == "y":
                sys.exit(1)
            shutil.rmtree(path)
        writer = SummaryWriter(path)
    else:
        writer = None

    # Load a model checkpoint
    ckpt = io_utils.load_ckpt(prog_args)
    cg_dict = ckpt["cg"]  # get computation graph
    input_dim = cg_dict["feat"].shape[2]
    num_classes = cg_dict["pred"].shape[2]
    print("Loaded model from {}".format(prog_args.ckptdir))
    print("input dim: ", input_dim, "; num classes: ", num_classes)

    # Determine explainer mode
    graph_mode = (prog_args.graph_mode or prog_args.multigraph_class >= 0
                  or prog_args.graph_idx >= 0)

    # build model
    print("Method: ", prog_args.method)
    if graph_mode:
        # Explain Graph prediction
        model = models.GcnEncoderGraph(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    else:
        if prog_args.dataset == "ppi_essential":
            # Class weights in the CE loss for imbalanced label classes. Only
            # move them to the GPU when one is used; an unconditional
            # .cuda() crashed CPU-only runs.
            loss_weight = torch.tensor([1.0, 5.0], dtype=torch.float)
            if prog_args.gpu:
                loss_weight = loss_weight.cuda()
            prog_args.loss_weight = loss_weight
        # Explain Node prediction
        model = models.GcnEncoderNode(
            input_dim=input_dim,
            hidden_dim=prog_args.hidden_dim,
            embedding_dim=prog_args.output_dim,
            label_dim=num_classes,
            num_layers=prog_args.num_gc_layers,
            bn=prog_args.bn,
            args=prog_args,
        )
    if prog_args.gpu:
        model = model.cuda()
    # load state_dict (obtained by model.state_dict() when saving checkpoint)
    model.load_state_dict(ckpt["model_state"])

    # Create explainer
    explainer = explain.Explainer(
        model=model,
        adj=cg_dict["adj"],
        feat=cg_dict["feat"],
        label=cg_dict["label"],
        pred=cg_dict["pred"],
        train_idx=cg_dict["train_idx"],
        args=prog_args,
        writer=writer,
        print_training=True,
        graph_mode=graph_mode,
        graph_idx=prog_args.graph_idx,
    )

    # TODO: API should definitely be cleaner
    # Let's define exactly which modes we support
    # We could even move each mode to a different method (even file)
    if prog_args.explain_node is not None:
        explainer.explain(prog_args.explain_node, unconstrained=False)
    elif graph_mode:
        if prog_args.multigraph_class >= 0:
            print(cg_dict["label"])
            # only run for graphs with label specified by multigraph_class
            # NOTE(review): .numpy() assumes a torch tensor here, whereas the
            # node branch below treats cg_dict["label"] as numpy — confirm.
            labels = cg_dict["label"].numpy()
            graph_indices = []
            for i, l in enumerate(labels):
                if l == prog_args.multigraph_class:
                    graph_indices.append(i)
                # Checked after appending, so at most 31 graphs are kept.
                if len(graph_indices) > 30:
                    break
            print(
                "Graph indices for label ",
                prog_args.multigraph_class,
                " : ",
                graph_indices,
            )
            explainer.explain_graphs(graph_indices=graph_indices)

        elif prog_args.graph_idx == -1:
            # just run for a customized set of indices
            explainer.explain_graphs(graph_indices=[1, 2, 3, 4])
        else:
            explainer.explain(
                node_idx=0,
                graph_idx=prog_args.graph_idx,
                graph_mode=True,
                unconstrained=False,
            )
            # The colormap goes to the writer; skip it when logging is off
            # instead of passing None.
            if writer is not None:
                io_utils.plot_cmap_tb(writer, "tab20", 20, "tab20_cmap")
    else:
        if prog_args.multinode_class >= 0:
            print(cg_dict["label"])
            # only run for nodes with label specified by multinode_class
            labels = cg_dict["label"][0]  # already numpy matrix

            node_indices = []
            for i, l in enumerate(labels):
                # Checked before appending, so at most 5 nodes are kept.
                if len(node_indices) > 4:
                    break
                if l == prog_args.multinode_class:
                    node_indices.append(i)
            print(
                "Node indices for label ",
                prog_args.multinode_class,
                " : ",
                node_indices,
            )
            explainer.explain_nodes(node_indices, prog_args)

        else:
            # explain a set of nodes
            explainer.explain_nodes_gnn_stats(range(400, 700, 5), prog_args)
Example #5
0
    def explain(self,
                node_idx,
                graph_idx=0,
                graph_mode=False,
                unconstrained=False,
                exp_model="exp"):
        """Train an explanation mask for one node, then save and plot it.

        Prints extensive diagnostics along the way.

        Args:
            node_idx: Index of the node to explain in graph ``graph_idx``.
            graph_idx: Index of the graph inside ``self.adj`` / ``self.feat``
                / ``self.label`` / ``self.pred``.
            graph_mode: If True, use the whole graph ``graph_idx`` instead of
                the neighbourhood extracted around ``node_idx``.
            unconstrained: Forwarded to the ``ExplainModule`` forward call.
            exp_model: ``"exp"`` trains for ``self.args.num_epochs`` epochs and
                uses the module's masked adjacency. Any other value stops after
                one optimisation step and derives the adjacency mask from the
                sigmoid of the attention output ``adj_atts``.

        Returns:
            tuple: ``(masked_adj, masked_edges, masked_features)`` as numpy
            arrays. ``masked_adj`` is also saved as ``.npy`` under
            ``self.args.logdir`` and plotted via ``self.PlotSubGraph``.
        """
        print('************** Explaining node : {} **************'.format(
            node_idx))
        print('The label for graph index {} and node index {} : {}'.format(
            graph_idx, node_idx, self.label[graph_idx][node_idx]))
        print("Labels of all the nodes :\n", self.label)

        # Precomputed neighbourhoods of the node
        print("Shape of retrieved neighborhoods :", self.neighborhoods.shape)
        print("No. of neighborhoods :",
              len(self.neighborhoods[graph_idx][node_idx]))
        print(
            'List of neighborhoods for explaining node {} :'.format(node_idx))
        print(self.neighborhoods[graph_idx][node_idx])

        # index of the query node in the new adj
        if graph_mode:
            node_idx_new = node_idx
            sub_adj = self.adj[graph_idx]
            sub_feat = self.feat[graph_idx, :]
            sub_label = self.label[graph_idx]
            # NOTE(review): shape[0] of self.adj indexes graphs (see
            # self.adj[graph_idx] above), not nodes — confirm intent.
            neighbors = np.asarray(range(self.adj.shape[0]))
        else:
            print("Ground truth, node label :",
                  self.label[graph_idx][node_idx])
            # Computational graph: subgraph adjacency, features and labels,
            # plus the original-graph indices of the subgraph's nodes.
            node_idx_new, sub_adj, sub_feat, sub_label, neighbors = self.extract_neighborhood(
                node_idx, graph_idx)
            sub_label = np.expand_dims(sub_label, axis=0)

        # Add a leading batch axis of size 1.
        sub_adj = np.expand_dims(sub_adj, axis=0)
        sub_feat = np.expand_dims(sub_feat, axis=0)

        print("Neighbouring graph index for node " + str(node_idx) +
              " with new node index " + str(node_idx_new))
        print("Expand dimension of Subgraph label :\n", sub_label)

        # Original-graph indices of the subgraph nodes
        print("Subgraph neighbors :\n", neighbors)
        tensor_adj = torch.tensor(sub_adj, dtype=torch.float)
        tensor_x = torch.tensor(sub_feat,
                                requires_grad=True,
                                dtype=torch.float)
        tensor_label = torch.tensor(sub_label, dtype=torch.long)

        if self.graph_mode:
            pred_label = np.argmax(self.pred[0][graph_idx], axis=0)
            print("Graph predicted label: ", pred_label)
        else:
            pred_label = np.argmax(self.pred[graph_idx][neighbors], axis=1)
            print("Neighbours of predicted node labels :",
                  self.pred[graph_idx][neighbors])
            print(
                'Predicted labels for all {} neighbours (includes itself) :\n{}'
                .format(len(pred_label), pred_label))
            print('Predicted label for node {} : {}'.format(
                node_idx, pred_label[node_idx_new]))

        # The module works on the tensor versions of the subgraph.
        explainerMod = ExplainModule(
            adj=tensor_adj,
            x=tensor_x,
            model=self.model,
            label=tensor_label,
            args=self.args,
            writer=self.writer,
            graph_idx=self.graph_idx,
            graph_mode=self.graph_mode
        )
        # Keep the explanation module on the same device as a GPU model.
        if self.args.gpu:
            explainerMod = explainerMod.cuda()

        self.model.eval()
        explainerMod.train()
        begin_time = time.time()

        for epoch in range(self.args.num_epochs):
            explainerMod.zero_grad()
            explainerMod.optimizer.zero_grad()

            # node_idx_new is the query node's index inside the subgraph.
            ypred, adj_atts = explainerMod(node_idx_new,
                                           unconstrained=unconstrained)
            loss = explainerMod.loss(ypred, pred_label, node_idx_new, epoch)
            loss.backward()

            explainerMod.optimizer.step()
            mask_density = explainerMod.mask_density()

            print("epoch: ", epoch, "; loss: ", loss.item(),
                  "; mask density: ", mask_density.item(), "; pred: ", ypred)
            print(
                "------------------------------------------------------------------"
            )

            # Non-"exp" methods only need a single forward/backward pass.
            if exp_model != "exp":
                break

        print("\n--------------------------------------------")
        print("Final ypred after training : ", ypred)
        print("pred_label : ", pred_label)
        print("node_idx_new : ", node_idx_new)

        print("Completed training in ", time.time() - begin_time)

        if exp_model == "exp":
            masked_adj = (explainerMod.masked_adj[0].cpu().detach().numpy() *
                          sub_adj.squeeze())

            ypred_detach = ypred.cpu().detach().numpy()
            ypred_node = np.argmax(ypred_detach, axis=0)  # labels

            print('Detach ypred : {} and Argmax node : {}'.format(
                ypred_detach, ypred_node))
        else:
            # Previously masked_adj was never assigned on this path, so the
            # code below raised UnboundLocalError. Derive the mask from the
            # attention output of the last forward pass instead.
            masked_adj = (torch.sigmoid(adj_atts).squeeze().cpu().detach().numpy() *
                          sub_adj.squeeze())

        # Learned edge and feature masks, read for every mode.
        # NOTE(review): assumes ExplainModule always defines `mask` and
        # `feat_mask` — confirm for non-"exp" usage.
        masked_edges = explainerMod.mask.cpu().detach().numpy()
        masked_features = explainerMod.feat_mask.cpu().detach().numpy()

        # Trained masked, edges and features adjacency matrices
        print("Shape of masked adjacency matrix : ", masked_adj.shape)
        print("The masked adjacency matrix at index [0] :\n", masked_adj[0])
        print("Shape of masked edges matrix : ", masked_edges.shape)
        print("The masked edges adjacency matrix at index [0] :\n",
              masked_edges[0])
        print("Shape of masked features matrix : ", masked_features.shape)
        print("The masked features adjacency matrix at index [0] :\n",
              masked_features[0])

        fname = 'masked_adj_' + io_utils.gen_explainer_prefix(
            self.args) + ('_node_idx_' + str(node_idx) + '_graph_idx_' +
                          str(self.graph_idx) + '.npy')
        with open(os.path.join(self.args.logdir, fname), 'wb') as outfile:
            np.save(outfile, np.asarray(masked_adj.copy()))
            print("Saved adjacency matrix to \"" + fname + "\".")

        # Plot the explanation subgraph
        self.PlotSubGraph(masked_adj,
                          masked_edges,
                          node_idx_new,
                          node_idx,
                          feats=sub_feat.squeeze(),
                          labels=tensor_label.cpu().detach().numpy().squeeze(),
                          threshold_num=12,
                          adj_mode=True)

        return masked_adj, masked_edges, masked_features