def test_gnn_explainer_explain_graph(model, return_type, allow_edge_mask,
                                     feat_mask_type):
    """Smoke-test ``GNNExplainer.explain_graph`` across mask configurations.

    Checks that the returned feature/edge masks have the shape implied by
    ``feat_mask_type``, lie in [0, 1], and are accepted by
    ``visualize_subgraph`` for whole-graph plots (``node_idx=-1``).
    """
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)

    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    node_feat_mask, edge_mask = explainer.explain_graph(x, edge_index)
    if feat_mask_type == 'scalar':
        # Removed a dead `pass` statement here; the scalar per-node mask
        # doubles as the plot's per-node alpha values.
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            y=torch.tensor(2), threshold=0.8,
                                            node_alpha=node_feat_mask)
    else:
        edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))
        _, _ = explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                            edge_y=edge_y, y=torch.tensor(2),
                                            threshold=0.8)
    if feat_mask_type == 'individual_feature':
        assert node_feat_mask.size() == x.size()
    elif feat_mask_type == 'scalar':
        assert node_feat_mask.size() == (x.size(0), )
    else:
        assert node_feat_mask.size() == (x.size(1), )
    assert node_feat_mask.min() >= 0 and node_feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.max() <= 1 and edge_mask.min() >= 0
def find_noise_feats_by_GNNExplainer(model, data, args):
    """Count how many "noise" features GNNExplainer selects per test node.

    For each sampled node, the top-``args.K`` entries of the learned
    feature mask that exceed ``args.masks_threshold`` are inspected;
    features with an index at or beyond ``INPUT_DIM[args.dataset]`` are
    counted as noise.

    Returns a list with one noise-feature count per explained node.
    """
    explainer = GNNExplainer(model,
                             epochs=args.masks_epochs,
                             lr=args.masks_lr,
                             num_hops=args.hop,
                             log=False)

    node_indices = extract_test_nodes(data, args.test_samples)
    noise_dim_start = INPUT_DIM[args.dataset]

    num_noise_feats = []
    for node_idx in tqdm(node_indices, desc='explain node', leave=False):
        feat_mask, _ = explainer.explain_node(node_idx, data.x,
                                              data.edge_index)
        feat_mask = feat_mask.detach().cpu().numpy()

        # Indices of the K largest mask values, filtered by threshold.
        top_k = feat_mask.argsort()[-args.K:]
        selected = [i for i in top_k if feat_mask[i] > args.masks_threshold]

        num_noise_feats.append(sum(i >= noise_dim_start for i in selected))

    return num_noise_feats
def test_gnn_explainer_to_log_prob(model):
    """All three ``return_type`` modes must normalize to identical log-probs."""
    converters = {
        rt: GNNExplainer(model, return_type=rt).__to_log_prob__
        for rt in ('raw', 'prob', 'log_prob')
    }

    raw = torch.tensor([[1, 3.2, 6.1], [9, 9, 0.1]])
    prob = raw.softmax(dim=-1)
    log_prob = raw.log_softmax(dim=-1)

    assert torch.allclose(converters['raw'](raw), converters['prob'](prob))
    assert torch.allclose(converters['prob'](prob),
                          converters['log_prob'](log_prob))
def test_gnn_explainer_with_existing_self_loops(model, return_type):
    """``explain_node`` must cope with graphs that already carry self-loops."""
    explainer = GNNExplainer(model, log=False, return_type=return_type)

    x = torch.randn(8, 3)
    # First edge (0 -> 0) is an explicit self-loop.
    edge_index = torch.tensor([[0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [0, 1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

    feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    assert feat_mask.size() == (x.size(1), )
    assert feat_mask.min() >= 0 and feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_gnn_explainer(model):
    """Basic ``explain_node`` smoke test on a small undirected chain graph."""
    explainer = GNNExplainer(model, log=False)
    assert repr(explainer) == 'GNNExplainer()'

    x = torch.randn(8, 3)
    row = [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7]
    col = [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]
    edge_index = torch.tensor([row, col])

    feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    assert feat_mask.size() == (x.size(1), )
    assert feat_mask.min() >= 0 and feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_graph_explainer(model):
    """``explain_graph`` must clear the model's edge masks and bound outputs."""
    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    explainer = GNNExplainer(model, log=False)
    feat_mask, edge_mask = explainer.explain_graph(x, edge_index)

    # The explainer must not leave its temporary masks on the model.
    assert_edgemask_clear(model)
    assert feat_mask.size() == (x.size(1), )
    assert feat_mask.min() >= 0 and feat_mask.max() <= 1
    assert edge_mask.shape[0] == edge_index.shape[1]
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_gnn_explainer_explain_graph(model):
    """Whole-graph explanation plus visualization (``node_idx=-1``)."""
    explainer = GNNExplainer(model, log=False)

    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 5, 4, 6, 5, 7, 6]])

    feat_mask, edge_mask = explainer.explain_graph(x, edge_index)

    # node_idx=-1 selects the entire graph for plotting.
    explainer.visualize_subgraph(-1, edge_index, edge_mask,
                                 y=torch.tensor(2), threshold=0.8)

    assert feat_mask.size() == (x.size(1), )
    assert feat_mask.min() >= 0 and feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1
def test_gnn_explainer_explain_node(model, return_type, allow_edge_mask,
                                    feat_mask_type):
    """``explain_node`` across mask configurations, including visualization."""
    explainer = GNNExplainer(model, log=False, return_type=return_type,
                             allow_edge_mask=allow_edge_mask,
                             feat_mask_type=feat_mask_type)
    assert repr(explainer) == 'GNNExplainer()'

    x = torch.randn(8, 3)
    y = torch.tensor([0, 1, 1, 0, 1, 0, 1, 0])
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])

    feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    if feat_mask_type == 'scalar':
        # The scalar per-node mask doubles as plot transparency.
        explainer.visualize_subgraph(2, edge_index, edge_mask, y=y,
                                     threshold=0.8, node_alpha=feat_mask)
    else:
        edge_y = torch.randint(low=0, high=30, size=(edge_index.size(1), ))
        explainer.visualize_subgraph(2, edge_index, edge_mask, y=y,
                                     edge_y=edge_y, threshold=0.8)

    if feat_mask_type == 'individual_feature':
        assert feat_mask.size() == x.size()
    elif feat_mask_type == 'scalar':
        assert feat_mask.size() == (x.size(0), )
    else:
        assert feat_mask.size() == (x.size(1), )
    assert feat_mask.min() >= 0 and feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1

    if not allow_edge_mask:
        # Without a learnable edge mask, the 8 subgraph edges stay fully
        # on and the remaining 6 edges stay fully off.
        assert edge_mask[:8].tolist() == [1.] * 8
        assert edge_mask[8:].tolist() == [0.] * 6
Example #9
0
else:
    for epoch in range(1, 201):
        model.train()
        optimizer.zero_grad()
        log_logits = model(x, edge_index)
        loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        torch.save(model, model_path)
# Evaluate the trained model on the held-out test users.
log_logists = model(x, edge_index)
# Test nodes occupy the index range directly after train + dev users.
y_pred_test = torch.argmax(log_logists, dim=1)[np.arange(len(U_train + U_dev), len(U_train + U_dev + U_test))]
y_pred_test = y_pred_test.detach().numpy()
mean, median, acc, _, _, _ = geo_eval(Y_test, y_pred_test, U_test, classLatMedian, classLonMedian, userLocation)
print(f"mean:{mean} median: {median} acc: {acc}")

# Explain a single node's prediction with GNNExplainer.
explainer = GNNExplainer(model, epochs=200)
node_idx = 9000
print(userLocation[U[node_idx]])

node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
print("top features:\n", node_feat_mask.argsort()[-10:])
print("top edges:\n", edge_index[:, edge_mask.argsort()[-10:]])
only_topk_edges = 20
# Distance from the explained user to every other user, used to color the
# visualization. NOTE(review): units depend on get_distance — confirm.
all_distances = torch.FloatTensor([get_distance(U[node_idx], u, userLocation) for u in U])
#ax, G = visualize_subgraph2(explainer, node_idx, edge_index, edge_mask, y=all_distances, threshold=None, only_topk_edges=only_topk_edges, cmap=plt.cm.cool)
#plt.savefig(f"{node_idx}-{U[node_idx]}-{only_topk_edges}.pdf")
#plt.close()
#print(data.y)

ax, G = visualize_subgraph3(explainer, node_idx, edge_index, edge_mask, y=all_distances, threshold=None, only_topk_edges=only_topk_edges, cmap=plt.cm.cool)
plt.savefig(f"{node_idx}-{U[node_idx]}-{only_topk_edges} - original.pdf")
Example #10
0
# Gather some statistics about the first graph.
print(f'Number of nodes: {single_graph.num_nodes}')
print(f'Number of edges: {single_graph.num_edges}')
print(
    f'Average node degree: {single_graph.num_edges / single_graph.num_nodes:.2f}'
)
print(f'Contains isolated nodes: {single_graph.contains_isolated_nodes()}')
print(f'Contains self-loops: {single_graph.contains_self_loops()}')
print(f'Is undirected: {single_graph.is_undirected()}')

single_graph = single_graph.to(device)
x, edge_index = single_graph.x, single_graph.edge_index

# explain a node
explainer = GNNExplainer(model, epochs=200)
# need to do the union of all the nodes to get full graph explanation...
# Accumulators for merging per-node masks into a graph-level explanation.
feat_mask_all = torch.zeros(dataset.num_features)
edge_mask_all = torch.zeros(single_graph.num_edges)

node_idx = 0
# The batch vector maps every node to graph 0, as required when the
# underlying model operates on (batched) whole graphs.
node_feat_mask, edge_mask = explainer.explain_node(node_idx,
                                                   x,
                                                   edge_index,
                                                   batch=torch.zeros(
                                                       single_graph.num_nodes,
                                                       dtype=torch.int64))

ax, G = explainer.visualize_subgraph(node_idx,
                                     edge_index,
                                     edge_mask,
Example #11
0
        self.conv1 = GCNConv(dataset.num_features, 16, normalize=False)
        self.conv2 = GCNConv(16, dataset.num_classes, normalize=False)

    def forward(self, x, edge_index, edge_weight):
        """Two-layer weighted GCN forward pass; returns log-probabilities."""
        hidden = self.conv1(x, edge_index, edge_weight)
        hidden = F.dropout(F.relu(hidden), training=self.training)
        out = self.conv2(hidden, edge_index, edge_weight)
        return F.log_softmax(out, dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight

# Standard 200-epoch full-batch training on the train-mask nodes.
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(x, edge_index, edge_weight)
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

# The model outputs log-probabilities, hence return_type='log_prob'.
explainer = GNNExplainer(model, epochs=200, return_type='log_prob')
node_idx = 10
# edge_weight is forwarded to the model while the masks are optimized.
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index,
                                                   edge_weight=edge_weight)
ax, G = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
plt.show()
Example #12
0
        best_loss = test_loss if test_loss < best_loss else best_loss
        if log and epoch % 50 == 1:
            print(best_loss)
    return model, x, data.y, edge_index


def extract_subgraph(node_idx, num_hops, edge_index):
    """Return the ``num_hops``-hop ``edge_index`` around ``node_idx``.

    Only the filtered edge index is needed here; with ``k_hop_subgraph``'s
    default ``relabel_nodes=False`` the original node ids are kept, so
    ``node_idx`` is returned unchanged for the caller's convenience.
    (Dropped the previously-unused ``nodes``/``mapping`` bindings.)
    """
    _, new_edge_index, _, _ = k_hop_subgraph(node_idx, num_hops, edge_index)
    return new_edge_index, node_idx


def run_model(edge_mask, edge_index, model, node_idx):
    """Run ``model`` on the module-level ``x`` keeping only masked-in edges,
    and return the prediction row for ``node_idx`` as a numpy array.
    """
    keep = torch.tensor(edge_mask).to(device).bool()
    preds = model(x, edge_index[:, keep]).detach().cpu()
    return preds[node_idx].numpy()


if __name__ == '__main__':
    # Train, then explain node 549 with a 1-hop GNNExplainer.
    model, x, data.y, edge_index = train_model(log=True)
    explainer = GNNExplainer(model, epochs=1000, num_hops=1)
    node_idx = 549
    node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
    # Keep only edges whose mask value is a >10-sigma outlier.
    scores = stats.zscore(edge_mask.cpu().numpy())
    idxs = scores > 10
    print(sum(idxs))
    ax, G = explainer.visualize_subgraph(node_idx,
                                         edge_index.T[idxs].T,
                                         edge_mask[idxs],
                                         y=data.y)
    plt.savefig('explain/node%d.png' % node_idx, dpi=300)
Example #13
0
from sklearn.metrics import roc_auc_score

# Compare edge-mask ROC AUC of BayesianExplainer vs. GNNExplainer over all
# nodes; nodes whose explanation fails are skipped (best-effort).
auc = 0
auc_gnn_exp = 0
itr_auc = 0  # keep the final print well-defined even if no node succeeds
pbar = range(x.shape[0])
done = 0
for n in pbar:
    try:
        k = 3
        sharp = 1e-12
        splines = 6
        explainer = BayesianExplainer(model, n, k, x, edge_index, sharp, splines)
        avgs = explainer.train(epochs=3000, lr=5, lambd=5e-11, window=500, p=1.1, log=False)
        edge_mask = explainer.edge_mask()
        edges = explainer.edge_index_adj
        # Ground-truth labels for the edges of the explainer's subgraph.
        labs = edge_labels[explainer.subset, :][:, explainer.subset][edges[0, :], edges[1, :]]
        sub_idx = (labs.long().cpu().detach().numpy() == 1)
        itr_auc = roc_auc_score(labs.long().cpu().detach().numpy()[sub_idx], edge_mask.cpu().detach().numpy()[sub_idx])
        auc += itr_auc
        e_subset = explainer.edge_mask_hard
        explainer = GNNExplainer(model.to(device), epochs=1000, log=False)
        _, edge_mask = explainer.explain_node(n, x.to(device), edge_index.to(device))
        auc_gnn_exp += roc_auc_score(labs.long().cpu().detach().numpy()[sub_idx], edge_mask[e_subset].cpu().detach().numpy()[sub_idx])
        done += 1
        if n % 10 == 0:
            print('EPOCH: %d | AUC: %.3f | AUC GNN_EXP: %.3f | ITR AUC: %.3f' % (n, auc/done, auc_gnn_exp/done, itr_auc))
    except Exception:
        # Deliberate best-effort skip, but no longer a bare `except:` that
        # would also swallow KeyboardInterrupt/SystemExit.
        pass

# Fixed: the original final print had a stray leading space (top-level
# IndentationError) and divided by zero when every node failed.
if done:
    print('FINAL | AUC: %.3f | AUC GNN_EXP: %.3f | ITR AUC: %.3f' % (auc/done, auc_gnn_exp/done, itr_auc))
def test_gnn_explainer(model):
    """``explain_node`` plus visualization with every label/threshold combo."""
    explainer = GNNExplainer(model, log=False)
    assert repr(explainer) == 'GNNExplainer()'

    x = torch.randn(8, 3)
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
                               [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]])
    y = torch.randint(0, 6, (8, ), dtype=torch.long)

    feat_mask, edge_mask = explainer.explain_node(2, x, edge_index)

    assert feat_mask.size() == (x.size(1), )
    assert feat_mask.min() >= 0 and feat_mask.max() <= 1
    assert edge_mask.size() == (edge_index.size(1), )
    assert edge_mask.min() >= 0 and edge_mask.max() <= 1

    # Visualization must work with and without labels and thresholding.
    for labels in (None, y):
        for threshold in (None, 0.5):
            explainer.visualize_subgraph(2, edge_index, edge_mask,
                                         y=labels, threshold=threshold)
Example #15
0
        super(Net, self).__init__()
        self.lin = Sequential(Linear(10, 10))
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        """Two-layer GCN producing per-node log-probabilities."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, training=self.training)
        return F.log_softmax(self.conv2(hidden, edge_index), dim=1)


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
x, edge_index = data.x, data.edge_index

# Standard 200-epoch full-batch training on the train-mask nodes.
for epoch in range(1, 201):
    model.train()
    optimizer.zero_grad()
    log_logits = model(x, edge_index)
    loss = F.nll_loss(log_logits[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

explainer = GNNExplainer(model, epochs=200)
node_idx = 10
node_feat_mask, edge_mask = explainer.explain_node(node_idx, x, edge_index)
# NOTE(review): this rebinds `plt` to visualize_subgraph's return value,
# shadowing matplotlib.pyplot — harmless only while the `plt.show()` call
# below stays commented out.
plt = explainer.visualize_subgraph(node_idx, edge_index, edge_mask, y=data.y)
# plt.show()
# equal to the number of samples. Therefore, we use unsqueeze(0).
# Edge explainability via Captum's IntegratedGradients over an edge mask.
captum_model = to_captum(model, mask_type='edge', output_idx=output_idx)
edge_mask = torch.ones(data.num_edges, requires_grad=True, device=device)

ig = IntegratedGradients(captum_model)
ig_attr_edge = ig.attribute(edge_mask.unsqueeze(0),
                            target=target,
                            additional_forward_args=(data.x, data.edge_index),
                            internal_batch_size=1)

# Scale attributions to [0, 1]:
ig_attr_edge = ig_attr_edge.squeeze(0).abs()
ig_attr_edge /= ig_attr_edge.max()

# Visualize absolute values of attributions with GNNExplainer visualizer
explainer = GNNExplainer(model)  # TODO: Change to general Explainer visualizer
ax, G = explainer.visualize_subgraph(output_idx, data.edge_index, ig_attr_edge)
plt.show()

# Node explainability
# ===================

captum_model = to_captum(model, mask_type='node', output_idx=output_idx)

ig = IntegratedGradients(captum_model)
# NOTE(review): `(data.edge_index)` below is not a tuple (no trailing
# comma); Captum normalizes a single extra argument itself, but adding the
# comma would make the intent explicit — confirm before changing.
ig_attr_node = ig.attribute(data.x.unsqueeze(0),
                            target=target,
                            additional_forward_args=(data.edge_index),
                            internal_batch_size=1)

# Scale attributions to [0, 1]:
    test_correct = int((pred[test_idx] == data.y[test_idx]).sum())
    test_acc = test_correct / test_idx.size(0)

    return train_acc, test_acc


# Train for 2000 epochs, reporting accuracy every 200 epochs.
for epoch in range(1, 2001):
    loss = train()
    if epoch % 200 == 0:
        train_acc, test_acc = test()
        print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Test: {test_acc:.4f}')

model.eval()
targets, preds = [], []
expl = GNNExplainer(model, epochs=300, return_type='raw', log=False)

# Explanation ROC AUC over all test nodes:
self_loop_mask = data.edge_index[0] != data.edge_index[1]
for node_idx in tqdm(data.expl_mask.nonzero(as_tuple=False).view(-1).tolist()):
    _, expl_edge_mask = expl.explain_node(node_idx,
                                          data.x,
                                          data.edge_index,
                                          edge_weight=data.edge_weight)
    # k_hop_subgraph returns (subset, edge_index, mapping, edge_mask);
    # element [3] is the boolean mask of edges inside the 3-hop subgraph.
    subgraph = k_hop_subgraph(node_idx, num_hops=3, edge_index=data.edge_index)
    expl_edge_mask = expl_edge_mask[self_loop_mask]
    subgraph_edge_mask = subgraph[3][self_loop_mask]
    targets.append(data.edge_label[subgraph_edge_mask].cpu())
    preds.append(expl_edge_mask[subgraph_edge_mask].cpu())

auc = roc_auc_score(torch.cat(targets), torch.cat(preds))