def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask,
                                    feat_mask_type):
    """Run ``GNNExplainer.explain_node`` on node 2 of an 8-node path graph.

    Checks the ``repr`` and renders the explanation with
    ``visualize_subgraph``. Depending on ``feat_mask_type``, the render gets
    either ``node_alpha`` or random ``edge_y`` labels. The test then validates
    the shape and ``[0, 1]`` range of both returned masks. When
    ``allow_edge_mask`` is false, the edge mask is also compared against a
    fixed 0/1 pattern.
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])
    # Undirected path 0-1-...-7; every edge appears once per direction.
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    num_nodes, num_feats = x.size()
    num_edges = edge_index.size(1)

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    # A 'scalar' mask has one entry per node (see the size check below), so it
    # is passed as node transparency; other mask types get random edge labels.
    if feat_mask_type == 'scalar':
        vis_kwargs = dict(node_alpha=node_feat_mask)
    else:
        vis_kwargs = dict(
            edge_y=torch.randint(low=0, high=30, size=(num_edges, )))
    _ax, _graph = explainer.visualize_subgraph(2, edge_index, edge_mask, y=y,
                                               threshold=0.8, **vis_kwargs)

    expected_feat_shapes = {
        'individual_feature': (num_nodes, num_feats),
        'scalar': (num_nodes, ),
    }
    assert node_feat_mask.size() == expected_feat_shapes.get(
        feat_mask_type, (num_feats, ))
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (num_edges, )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
    if not allow_edge_mask:
        # Expected pattern: first 8 edges fully kept, last 6 fully masked.
        assert edge_mask.tolist() == [1.] * 8 + [0.] * 6
def find_noise_feats_by_GNNExplainer(model, data, args):
    """Count noise features among GNNExplainer's top-``K`` features per node.

    For every node returned by ``extract_test_nodes(data, args.test_samples)``,
    the node feature mask is computed with ``GNNExplainer``. Of the
    ``args.K`` highest-scoring feature indices, only those whose mask value
    exceeds ``args.masks_threshold`` are kept. The function then counts how
    many kept indices are ``>= INPUT_DIM[args.dataset]``.

    Args:
        model: The GNN being explained.
        data: Graph data exposing ``x`` and ``edge_index``.
        args: Namespace providing ``masks_epochs``, ``masks_lr``, ``hop``,
            ``test_samples``, ``K``, ``masks_threshold`` and ``dataset``.

    Returns:
        list[int]: One count per explained node, in iteration order.
    """
    explainer = GNNExplainer(model,
                             epochs=args.masks_epochs,
                             lr=args.masks_lr,
                             num_hops=args.hop,
                             log=False)

    node_indices = extract_test_nodes(data, args.test_samples)
    # Feature indices at or above this bound are counted as noise features
    # (presumably features appended after the dataset's real inputs; confirm).
    input_dim = INPUT_DIM[args.dataset]
    k = max(args.K, 0)

    num_noise_feats = []
    for node_idx in tqdm(node_indices, desc='explain node', leave=False):
        node_feat_mask, _ = explainer.explain_node(
            node_idx, data.x, data.edge_index)
        node_feat_mask = node_feat_mask.detach().cpu().numpy()

        # Indices of the k largest mask values. ``argsort()[-k:]`` would
        # return *every* index when k == 0 (``[-0:]`` is ``[0:]``), so take
        # the head of the descending order instead; the selected set is
        # otherwise identical.
        top_k = node_feat_mask.argsort()[::-1][:k]
        feat_indices = [
            idx for idx in top_k
            if node_feat_mask[idx] > args.masks_threshold
        ]

        num_noise_feats.append(
            sum(1 for idx in feat_indices if idx >= input_dim))

    return num_noise_feats
def test_gnn_explainer_with_existing_self_loops(model, return_type):
    """Explain node 2 when ``edge_index`` already contains a self-loop.

    The graph is the 8-node path graph used elsewhere, with an extra
    ``(0, 0)`` edge at the front. Only mask shapes and ``[0, 1]`` ranges are
    checked.
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type)

    x = torch.randn(8, 3)
    src = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7]
    dst = [0, 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]
    edge_index = torch.tensor([src, dst])

    feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    assert feat_mask.size() == (x.size(1), )
    assert feat_mask.min() >= 0 and feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.max() <= 1 and edge_mask.min() >= 0
def test_gnn_explainer(model):
    """Check GNNExplainer mask shapes and ``[0, 1]`` ranges on a path graph.

    NOTE(review): another ``def test_gnn_explainer(model)`` appears later in
    this file and rebinds the name. This definition is therefore shadowed and
    never collected. Its assertions are a subset of the later one's, so
    consider renaming or removing it.
    """
    explainer = GNNExplainer(model, log=False)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    # Undirected path 0-1-...-7 stored with both edge directions.
    senders = [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7]
    receivers = [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]
    edge_index = torch.tensor([senders, receivers])

    feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    for mask, expected_size in ((feat_mask, (x.size(1), )),
                                (edge_mask, (edge_index.size(1), ))):
        assert mask.size() == expected_size
        assert mask.min() >= 0 and mask.max() <= 1
def test_gnn_explainer(model):
    """Explain node 2 of a path graph, check the masks, then render them.

    ``visualize_subgraph`` is exercised without and with node labels ``y``,
    each with ``threshold`` set to ``None`` and to ``0.5``.
    """
    explainer = GNNExplainer(model, log=False)
    assert explainer.__repr__() == 'GNNExplainer()'

    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    # Random labels are drawn before explain_node runs.
    y = torch.randint(0, 6, (8, ), dtype=torch.long)

    node_feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    checks = ((node_feat_mask, (x.size(1), )),
              (edge_mask, (edge_index.size(1), )))
    for mask, expected_size in checks:
        assert mask.size() == expected_size
        assert mask.min() >= 0 and mask.max() <= 1

    # Unlabeled first, then labeled; each with no threshold, then 0.5.
    for label_kwargs in ({}, {'y': y}):
        for threshold in (None, 0.5):
            explainer.visualize_subgraph(2, edge_index, edge_mask,
                                         threshold=threshold, **label_kwargs)
Example #6
0
        super(Net, self).__init__()
        self.lin = Sequential(Linear(10, 10))
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        """Apply conv1 -> ReLU -> dropout -> conv2, then log_softmax over dim 1.

        Dropout is only active while ``self.training`` is true.
        """
        hidden = self.conv1(x, edge_index)
        hidden = F.relu(hidden)
        hidden = F.dropout(hidden, training=self.training)
        scores = self.conv2(hidden, edge_index)
        return F.log_softmax(scores, dim=1)


# Use the GPU when available; move the model and the graph onto that device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index

# Full-graph training for 200 epochs. The loss only uses the nodes selected
# by data.train_mask.
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(x, edge_index)
    # nll_loss expects log-probabilities; Net.forward ends in log_softmax.
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# Explain the trained model's output for node 10 (explainer run with
# epochs=200), then draw the explanatory subgraph with data.y as labels.
explainer = GNNExplainer(model, epochs=200)
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
# NOTE(review): this rebinds `plt` to whatever visualize_subgraph returns,
# shadowing any matplotlib.pyplot import of that name. Other snippets in this
# file unpack the return value as a pair (ax, G), so depending on the
# torch_geometric version this may not be a pyplot module. Confirm before
# un-commenting plt.show().
plt = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
# plt.show()
    return train_acc, test_acc


# Train for 2000 epochs, printing the loss plus train/test accuracy every
# 200 epochs. train() and test() are defined outside this view.
for epoch in range(1, 2001):
    loss = train()
    if epoch % 200 == 0:
        train_acc, test_acc = test()
        print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Test: {test_acc:.4f}')

model.eval()
targets, preds = [], []
# return_type='raw': presumably tells the explainer the model outputs raw
# scores rather than log-probabilities; confirm against the model in use.
expl = GNNExplainer(model, epochs=300, return_type='raw', log=False)

# Explanation ROC AUC over all nodes selected by data.expl_mask:
# Boolean mask that is True for every edge that is not a self-loop.
self_loop_mask = data.edge_index[0] != data.edge_index[1]
for node_idx in tqdm(data.expl_mask.nonzero(as_tuple=False).view(-1).tolist()):
    _, expl_edge_mask = expl.explain_node(node_idx,
                                          data.x,
                                          data.edge_index,
                                          edge_weight=data.edge_weight)
    # 3-hop neighbourhood of node_idx. subgraph[3] is used below as a
    # per-edge boolean mask (presumably k_hop_subgraph's edge_mask output;
    # confirm).
    subgraph = k_hop_subgraph(node_idx, num_hops=3, edge_index=data.edge_index)
    # Drop self-loop edges from both the explanation and the subgraph mask.
    expl_edge_mask = expl_edge_mask[self_loop_mask]
    subgraph_edge_mask = subgraph[3][self_loop_mask]
    # NOTE(review): edge_label is indexed with a mask that already excludes
    # self-loops. This assumes edge_label has exactly one entry per
    # non-self-loop edge, in edge_index order; confirm.
    targets.append(data.edge_label[subgraph_edge_mask].cpu())
    preds.append(expl_edge_mask[subgraph_edge_mask].cpu())

# One ROC AUC over the pooled labels/scores of all nodes. Despite the
# "Mean" in the message, no per-node averaging happens here.
auc = roc_auc_score(torch.cat(targets), torch.cat(preds))
print(f'Mean ROC AUC: {auc:.4f}')
Example #8
0
print(f'Contains self-loops: {single_graph.contains_self_loops()}')
print(f'Is undirected: {single_graph.is_undirected()}')

single_graph = single_graph.to(device)
x, edge_index = single_graph.x, single_graph.edge_index

# Explain a single node of the graph.
explainer = GNNExplainer(model, epochs=200)
# A full-graph explanation would need the union of the per-node masks. These
# accumulators are allocated for that but are not filled in this section.
feat_mask_all = torch.zeros(dataset.num_features)
edge_mask_all = torch.zeros(single_graph.num_edges)

node_idx = 0
# Every node belongs to graph 0 of a one-graph batch. The batch vector is
# allocated on the same device as the features. A default (CPU) tensor would
# mismatch x/edge_index whenever `device` is CUDA.
node_feat_mask, edge_mask = explainer.explain_node(
    node_idx,
    x,
    edge_index,
    batch=torch.zeros(single_graph.num_nodes,
                      dtype=torch.int64,
                      device=x.device))

# Optionally pass y=single_graph.y to visualize_subgraph as node labels.
ax, G = explainer.visualize_subgraph(node_idx,
                                     edge_index,
                                     edge_mask,
                                     threshold=0.5)
plt.show()

#
# """
# Cora - node classification
# """

# import os.path as osp