def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask,
                                    feat_mask_type):
    """Node-level GNNExplainer test across return/mask configurations.

    Explains node 2 of an 8-node path graph, runs the visualizer, and checks
    the shape and ``[0, 1]`` range of the returned feature and edge masks.
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])
    # Undirected path 0-1-2-...-7, each edge stored in both directions.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    num_edges = edge_index.size(1)

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    # A scalar (per-node) mask can drive node transparency; otherwise the
    # visualizer is given random per-edge labels instead.
    viz_kwargs = dict(y=y, threshold=0.8)
    if feat_mask_type == 'scalar':
        viz_kwargs['node_alpha'] = node_feat_mask
    else:
        viz_kwargs['edge_y'] = torch.randint(low=0, high=30,
                                             size=(num_edges, ))
    _, _ = explainer.visualize_subgraph(2, edge_index, edge_mask,
                                        **viz_kwargs)

    expected_feat_shape = {
        'individual_feature': tuple(x.size()),
        'scalar': (x.size(0), ),
    }.get(feat_mask_type, (x.size(1), ))
    assert node_feat_mask.size() == expected_feat_shape
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1

    assert edge_mask.size() == (num_edges, )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1

    if not allow_edge_mask:
        # The first 8 edges connect nodes 0-4, i.e. lie within two hops of
        # node 2 (presumably the model's receptive field - confirm). Without
        # a learned edge mask they are expected to be exactly 1, the rest 0.
        assert edge_mask[:8].tolist() == [1.] * 8
        assert edge_mask[8:].tolist() == [0.] * 6
def test_gnn_explainer_explain_graph(model, return_type, allow_edge_mask,
                                     feat_mask_type):
    """Graph-level GNNExplainer test across return/mask configurations.

    Runs ``explain_graph`` and ``visualize_subgraph`` on two disconnected
    paths and checks the shape and ``[0, 1]`` range of the returned masks.

    NOTE(review): a later function with this same name in the same module
    would shadow this parametrized test; confirm only one definition exists.
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)

    x = torch.randn(8, 3)
    # Two disconnected undirected paths: 0-1-2-3 and 4-5-6-7.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    # node_idx=-1 presumably selects whole-graph visualization.
    if feat_mask_type == 'scalar':
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            y=torch.tensor(2), threshold=0.8,
                                            node_alpha=node_feat_mask)
    else:
        edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            edge_y=edge_y, y=torch.tensor(2),
                                            threshold=0.8)
    if feat_mask_type == 'individual_feature':
        assert node_feat_mask.size() == x.size()
    elif feat_mask_type == 'scalar':
        assert node_feat_mask.size() == (x.size(0), )
    else:
        assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.max() <= 1 and edge_mask.min() >= 0
def test_gnn_explainer_explain_graph(model):
    """Graph-level GNNExplainer smoke test with default settings.

    NOTE(review): this redefines ``test_gnn_explainer_explain_graph``; if an
    earlier definition of that name exists in the same module, this one
    replaces it and the earlier test is never collected.
    """
    explainer = GNNExplainer(model, log=False)

    x = torch.randn(8, 3)
    # Two disconnected undirected paths: 0-1-2-3 and 4-5-6-7.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])
    num_feats = x.size(1)
    num_edges = edge_index.size(1)

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                        y=torch.tensor(2), threshold=0.8)

    assert node_feat_mask.size() == (num_feats, )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (num_edges, )
    assert edge_mask.max() <= 1 and edge_mask.min() >= 0
def test_gnn_explainer(model):
    """Node-level GNNExplainer test with default settings.

    Checks mask shapes/ranges for node 2, then runs ``visualize_subgraph``
    with and without labels and with and without a threshold.
    """
    explainer = GNNExplainer(model, log=False)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    y = torch.randint(0, 6, (8, ), dtype=torch.long)

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)
    assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1

    # Same four calls, same order: (no y, None), (no y, 0.5), (y, None),
    # (y, 0.5).
    for label_kwargs in ({}, {'y': y}):
        for threshold in (None, 0.5):
            explainer.visualize_subgraph(2, edge_index, edge_mask,
                                         threshold=threshold, **label_kwargs)
captum_model = to_captum(model, mask_type='edge', output_idx=output_idx)
# All-ones edge mask used as the Integrated Gradients input.
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0),
                            target=target,
                            additional_forward_args=(data.x, data.edge_index),
                            internal_batch_size=1)

# Scale attributions to [0, 1] (skip the division when every attribution is
# zero, which would otherwise fill the tensor with NaNs):
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
if ig_attr_edge.max() > 0:
    ig_attr_edge /= ig_attr_edge.max()

# Visualize absolute values of attributions with GNNExplainer visualizer
explainer = GNNExplainer(model)  # TODO: Change to general Explainer visualizer
ax, G = explainer.visualize_subgraph(output_idx, data.edge_index, ig_attr_edge)
plt.show()

# Node explainability
# ===================

captum_model = to_captum(model, mask_type='node', output_idx=output_idx)

ig = IntegratedGradients(captum_model)
ig_attr_node = ig.attribute(data.x.unsqueeze(0),
                            target=target,
                            # Trailing comma: pass a real 1-tuple rather than
                            # a bare tensor wrapped in parentheses.
                            additional_forward_args=(data.edge_index, ),
                            internal_batch_size=1)

# Scale attributions to [0, 1]: sum absolute attributions over dim 1
# (presumably the feature dimension, giving one score per node), then divide
# by the maximum. Dividing by the max is idempotent, so repeating it later is
# harmless.
ig_attr_node = ig_attr_node.squeeze(0).abs().sum(dim=1)
if ig_attr_node.max() > 0:
    ig_attr_node /= ig_attr_node.max()
# Example #6
# 0
        super(Net, self).__init__()
        self.lin = Sequential(Linear(10, 10))
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        """Apply GCN -> ReLU -> dropout -> GCN and return log-softmax scores.

        ``log_softmax`` is taken over dim 1, the output channel axis of
        ``conv2``.
        """
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index

# Train for 200 epochs with NLL loss on the training nodes only.
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(x, edge_index)
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Training left the model in train mode (dropout active); switch to eval so
# the explanation is computed on deterministic outputs. This is harmless if
# the explainer also does this internally.
model.eval()

explainer = GNNExplainer(model, epochs=200)
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
# visualize_subgraph returns an (axes, graph) pair. Unpack it instead of
# binding it to `plt`, which would shadow the conventional matplotlib name.
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
# plt.show()
        best_loss = test_loss if test_loss < best_loss else best_loss
        if log and epoch % 50 == 1:
            print(best_loss)
    return model, x, data.y, edge_index


def extract_subgraph(node_idx, num_hops, edge_index):
    """Return ``(subgraph_edge_index, node_idx)`` for the ``num_hops``-hop
    neighbourhood of ``node_idx``.

    ``node_idx`` is returned unchanged. Presumably it stays valid because
    ``k_hop_subgraph`` is called without relabelling nodes; confirm.
    """
    _, sub_edge_index, _, _ = k_hop_subgraph(node_idx, num_hops, edge_index)
    return sub_edge_index, node_idx


def run_model(edge_mask, edge_index, model, node_idx):
    """Run ``model`` using only the edges selected by ``edge_mask``.

    Returns the model output row for ``node_idx`` as a NumPy array. Relies on
    the module-level names ``x`` (node features) and ``device``.
    """
    # torch.as_tensor reuses an existing tensor instead of copy-constructing
    # it (torch.tensor(tensor) copies and emits a UserWarning). It still
    # accepts lists and NumPy arrays.
    keep = torch.as_tensor(edge_mask).to(device).bool()
    masked_edge_index = edge_index[:, keep]
    # Inference only: no autograd graph is needed for the detached output.
    with torch.no_grad():
        out = model(x, masked_edge_index).detach().cpu()
    return out[node_idx].numpy()


if __name__ == '__main__':
    import os

    model, x, data.y, edge_index = train_model(log=True)
    explainer = GNNExplainer(model, epochs=1000, num_hops=1)
    node_idx = 549
    node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
    # Keep only edges whose mask value is an extreme outlier (z-score > 10).
    scores = stats.zscore(edge_mask.cpu().numpy())
    idxs = scores > 10
    print(idxs.sum())
    ax, G = explainer.visualize_subgraph(node_idx,
                                         edge_index.T[idxs].T,
                                         edge_mask[idxs],
                                         y=data.y)
    # savefig does not create missing directories.
    os.makedirs('explain', exist_ok=True)
    plt.savefig('explain/node%d.png' % node_idx, dpi=300)
# Example #8
# 0
# Explain a single node.
explainer = GNNExplainer(model, epochs=200)
# Accumulators meant for a union over all nodes (a full-graph explanation);
# nothing below fills them in yet.
feat_mask_all = torch.zeros(dataset.num_features)
edge_mask_all = torch.zeros(single_graph.num_edges)

node_idx = 0
# The all-zeros batch vector assigns every node to the same graph (index 0).
node_feat_mask, edge_mask = explainer.explain_node(
    node_idx, x, edge_index,
    batch=torch.zeros(single_graph.num_nodes, dtype=torch.int64))

ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask,
                                     threshold=0.5)  # , y=single_graph.y)
plt.show()

#
# """
# Cora - node classification
# """

# import os.path as osp

# import torch
# import torch.nn.functional as F
# import matplotlib.pyplot as plt
# from torch_geometric.datasets import Planetoid
# import torch_geometric.transforms as T
# Example #9
# 0
        loss.backward()
        optimizer.step()

        # Testing step
        model.eval()
        test_loss = F.nll_loss(log_logits[test_mask], data.y[test_mask]).item()
        best_loss = test_loss if test_loss < best_loss else best_loss
        if log and epoch % 50 == 1:
            print(best_loss)
    return model, x, data.y, edge_index

def extract_subgraph(node_idx, num_hops, edge_index):
    """Return ``(subgraph_edge_index, node_idx)`` for the ``num_hops``-hop
    neighbourhood of ``node_idx``.

    ``node_idx`` is returned unchanged. Presumably it stays valid because
    ``k_hop_subgraph`` is called without relabelling nodes; confirm.
    """
    _, sub_edge_index, _, _ = k_hop_subgraph(node_idx, num_hops, edge_index)
    return sub_edge_index, node_idx

def run_model(edge_mask, edge_index, model, node_idx):
    """Run ``model`` using only the edges selected by ``edge_mask``.

    Returns the model output row for ``node_idx`` as a NumPy array. Relies on
    the module-level names ``x`` (node features) and ``device``.
    """
    # torch.as_tensor reuses an existing tensor instead of copy-constructing
    # it (torch.tensor(tensor) copies and emits a UserWarning). It still
    # accepts lists and NumPy arrays.
    keep = torch.as_tensor(edge_mask).to(device).bool()
    masked_edge_index = edge_index[:, keep]
    # Inference only: no autograd graph is needed for the detached output.
    with torch.no_grad():
        out = model(x, masked_edge_index).detach().cpu()
    return out[node_idx].numpy()

if __name__ == '__main__':
    import os

    model, x, data.y, edge_index = train_model(log=True)
    explainer = GNNExplainer(model, epochs=1000, num_hops=1)
    node_idx = 1
    node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
    # Move every visualizer input to the CPU. The original moved edge_mask
    # and y but not edge_index, mixing devices when training ran on a GPU.
    node_feat_mask = node_feat_mask.to('cpu')
    edge_mask = edge_mask.to('cpu')
    y = data.y.to('cpu')
    ax, G = explainer.visualize_subgraph(node_idx, edge_index.to('cpu'),
                                         edge_mask, y=y)
    # Neither write_edgelist nor savefig creates missing directories.
    os.makedirs('scripts/BI/random/explain', exist_ok=True)
    nx.readwrite.edgelist.write_edgelist(
        G, 'scripts/BI/random/explain/node%d.el' % node_idx, data=['att'])
    plt.savefig('scripts/BI/random/explain/node%d.png' % node_idx, dpi=300)