Пример #1
0
def test_k_hop_subgraph():
    """Exercise ``k_hop_subgraph`` on two small directed graphs."""
    # Layered graph: {0,1} -> 2, {2,3} -> 4, {4,5} -> 6.
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5],
                               [2, 2, 4, 4, 6, 6]])

    result = k_hop_subgraph(6, 2, edge_index, relabel_nodes=True)
    subset, edge_index, mapping, edge_mask = result
    assert subset.tolist() == [2, 3, 4, 5, 6]
    assert edge_index.tolist() == [[0, 1, 2, 3], [2, 2, 4, 4]]
    assert mapping.tolist() == [4]
    assert edge_mask.tolist() == [False, False, True, True, True, True]

    # Two seed nodes on two disjoint paths: 2 -> 1 -> 0 and 4 -> 5 -> 6.
    edge_index = torch.tensor([[1, 2, 4, 5],
                               [0, 1, 5, 6]])

    subset, edge_index, mapping, edge_mask = k_hop_subgraph(
        [0, 6], 2, edge_index, relabel_nodes=True)

    assert subset.tolist() == [0, 1, 2, 4, 5, 6]
    assert edge_index.tolist() == [[1, 2, 3, 4], [0, 1, 4, 5]]
    # `mapping` gives the positions of the seed nodes in `subset`.
    assert mapping.tolist() == [0, 5]
    assert edge_mask.tolist() == [True, True, True, True]
Пример #2
0
def sub_data_maker(data_name):
    """Build 1-hop ego subgraphs for every 100th node of a dataset and save
    them as a list of :class:`Data` objects under ``./data/Sub_<name>``.

    Args:
        data_name (str): Either ``"Flickr"`` or a Yelp dataset name.
    """
    name = 'Sub_{}'.format(data_name)
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', data_name)
    if data_name == "Flickr":
        dataset = Flickr(path, data_name)
    else:
        dataset = Yelp(path, data_name)

    print(data_name)
    start = time.perf_counter()
    f_data = []
    for i in range(0, 80000, 100):
        # Single call instead of two identical (expensive) k_hop_subgraph
        # calls: [0] is the node subset, [1] the relabelled edge_index.
        subset, adj, _, _ = k_hop_subgraph(i, 1, dataset.data.edge_index,
                                           relabel_nodes=True)
        index = subset.numpy()
        feature = dataset.data.x[index]
        label = dataset.data.y[index]
        f_data.append(Data(x=feature, edge_index=adj, y=label))

    os.makedirs('./data/{}/processed'.format(name), exist_ok=True)
    torch.save(f_data, './data/{}/processed/data.pt'.format(name))

    end = time.perf_counter()
    print("time consuming {:.2f}".format(end - start))

    print(f_data[:10])
Пример #3
0
    def extract_enclosing_subgraphs(self, link_index, edge_index, y):
        """Extract the ``num_hops``-hop enclosing subgraph around each link
        (SEAL-style), remove the target link itself, label nodes with DRNL,
        and return the resulting list of :class:`Data` objects, all carrying
        label ``y``."""
        data_list = []
        for src, dst in link_index.t().tolist():
            sub_nodes, sub_edge_index, mapping, _ = k_hop_subgraph(
                [src, dst], self.num_hops, edge_index, relabel_nodes=True)
            # `mapping` holds the relabelled positions of src and dst.
            src, dst = mapping.tolist()

            # Remove target link from the subgraph.
            mask1 = (sub_edge_index[0] != src) | (sub_edge_index[1] != dst)
            mask2 = (sub_edge_index[0] != dst) | (sub_edge_index[1] != src)
            sub_edge_index = sub_edge_index[:, mask1 & mask2]

            # Calculate node labeling.
            z = self.drnl_node_labeling(sub_edge_index,
                                        src,
                                        dst,
                                        num_nodes=sub_nodes.size(0))

            data = Data(x=self.data.x[sub_nodes],
                        z=z,
                        edge_index=sub_edge_index,
                        y=y)
            data_list.append(data)

        return data_list
Пример #4
0
def explainer2graph(i, explainer, edge_mask, edge_index, threshold, y):
    """Turn an explainer's edge mask for node ``i`` into a networkx graph.

    Restricts the mask to the k-hop subgraph around ``i`` (the explainer's
    receptive field), optionally binarizes it at ``threshold``, and relabels
    the nodes back to their original ids.
    """
    assert edge_mask.size(0) == edge_index.size(1)
    # Only the k-hop neighbourhood of `i` can influence its prediction.
    subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
        int(i),
        explainer.__num_hops__(),
        edge_index,
        relabel_nodes=True,
        num_nodes=None,
        flow=explainer.__flow__())

    edge_mask = edge_mask[hard_edge_mask]

    if threshold is not None:
        # Binarize: keep only edges whose importance reaches the threshold.
        edge_mask = (edge_mask >= threshold).to(torch.float)

    if y is None:
        y = torch.zeros(edge_index.max().item() + 1, device=edge_index.device)
    else:
        # Normalize labels to [0, 1] so they can be used as node colors.
        y = y[subset].to(torch.float) / y.max().item()

    data = Data(edge_index=edge_index, att=edge_mask, y=y,
                num_nodes=y.size(0)).to('cpu')
    G = to_networkx(data, node_attrs=['y'], edge_attrs=['att'])
    # Map relabelled (consecutive) node ids back to the original ids.
    mapping = {k: i for k, i in enumerate(subset.tolist())}
    G = nx.relabel_nodes(G, mapping)
    return G
Пример #5
0
 def calcLID(self, x, edge_index):
     """Estimate the mean Local Intrinsic Dimensionality (LID) of the node
     features ``x``, using each node's k-hop graph neighbourhood (cached in
     ``self.khops`` on first call) as its candidate neighbour set."""
     n_nodes = x.shape[0]
     ic(n_nodes)

     if self.khops is None:
         # Build and cache star edges: node -> each of its k-hop neighbours.
         khops = []
         for idx, xi in tqdm(enumerate(x), total=x.shape[0]):
             khop_nidxs, *_ = k_hop_subgraph(idx, num_hops=self.khop, edge_index=edge_index) # [N, ]
             n_khop = khop_nidxs.shape[0]
             # NOTE(review): `.to(x)` casts the just-created long tensor to
             # x's dtype/device; relies on the `.long()` below to restore
             # integer indices — confirm this matches `khop_nidxs`'s dtype.
             xs = (torch.ones(n_khop) * idx).view(1, -1).long().to(x) # [N, ]
             khop_edges = torch.cat([xs, khop_nidxs.view(1, -1)], dim=0) # [2, N]
             khops.append(khop_edges)
         khops = torch.cat(khops, dim=-1) # [2, N * N_NODES]
         khops, _ = tg.utils.remove_self_loops(khops)
         self.khops = khops.long()
         self._save_subgraphs()

     row, col = self.khops # [2, N * N_NODES]
     # Euclidean feature distance along every cached k-hop edge.
     dist = (x[row] - x[col]).norm(dim=-1)
     dense_A = tg.utils.to_dense_adj(self.khops, edge_attr=dist).squeeze(0) # N * N

     # Push non-edges (zeros) to ~infinity so topk(largest=False) skips them.
     dense_A = (dense_A == 0.).float() * 1e10 + dense_A

     topK = torch.topk(dense_A, k=self.knn, largest=False, dim=-1).values # N * K
     # MLE LID estimator: lid = -k / sum_i log(d_i / d_max).
     v_log = torch.log(topK.transpose(0, 1) / topK.transpose(0, 1)[-1] + 1e-8)
     v_log = v_log.transpose(0, 1).sum(dim=-1) # => [N, K] => [N]
     lid = - self.knn / v_log
     return lid.mean()
Пример #6
0
    def __subgraph__(self, node_idx, x, edge_index, **kwargs):
        """Restrict ``x``/``edge_index`` (and any node-/edge-aligned tensor
        kwargs) to the k-hop subgraph around ``node_idx``.

        If ``node_idx`` is ``None``, the full graph is returned with an
        all-``True`` edge mask and ``mapping=None``.
        """
        num_nodes, num_edges = x.size(0), edge_index.size(1)

        if node_idx is not None:
            subset, edge_index, mapping, edge_mask = k_hop_subgraph(
                node_idx,
                self.num_hops,
                edge_index,
                relabel_nodes=True,
                num_nodes=num_nodes,
                flow=self.__flow__())
            x = x[subset]
            # Tensors with a node-sized leading dim are subset; edge-sized
            # ones are masked; everything else passes through unchanged.
            for key, item in kwargs.items():
                if torch.is_tensor(item) and item.size(0) == num_nodes:
                    item = item[subset]
                elif torch.is_tensor(item) and item.size(0) == num_edges:
                    item = item[edge_mask]
                kwargs[key] = item
        else:
            x = x
            edge_index = edge_index
            row, col = edge_index
            # Full graph: every edge is kept.
            edge_mask = row.new_empty(row.size(0), dtype=torch.bool)
            edge_mask[:] = True
            mapping = None

        return x, edge_index, mapping, edge_mask, kwargs
Пример #7
0
    def subgraph(self, node_idx: int, x: Tensor, edge_index: Tensor, **kwargs):
        r"""Returns the subgraph of the given node.

        Args:
            node_idx (int): The node to explain.
            x (Tensor): The node feature matrix.
            edge_index (LongTensor): The edge indices.
            **kwargs (optional): Additional arguments passed to the GNN module.

        :rtype: (Tensor, Tensor, LongTensor, LongTensor, LongTensor, dict)
        """
        num_nodes, num_edges = x.size(0), edge_index.size(1)
        subset, edge_index, mapping, edge_mask = k_hop_subgraph(
            node_idx,
            self.num_hops,
            edge_index,
            relabel_nodes=True,
            num_nodes=num_nodes,
            flow=self._flow())

        x = x[subset]
        # Slice node-/edge-aligned tensor kwargs into a *fresh* dict so the
        # caller's kwargs are left untouched.
        kwargs_new = {}
        for key, value in kwargs.items():
            if torch.is_tensor(value) and value.size(0) == num_nodes:
                kwargs_new[key] = value[subset]
            elif torch.is_tensor(value) and value.size(0) == num_edges:
                kwargs_new[key] = value[edge_mask]
            else:
                kwargs_new[key] = value  # TODO: this is not in PGExplainer
        return x, edge_index, mapping, edge_mask, subset, kwargs_new
    def __init__(self, model, node_idx: int, k: int, x, edge_index, sharp: float = 0.01, splines: int = 6, sigmoid = True):
        """Set up a spline-flow sampler over the k-hop subgraph of
        ``node_idx``: one flow dimension per subgraph edge."""
        #device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        device = torch.device('cpu')
        self.model = model.to(device)
        self.x = x.to(device)
        self.edge_index = edge_index.to(device)
        self.node_idx = node_idx
        self.k = k
        self.sharp = sharp
        # Relabelled k-hop subgraph around the node being explained.
        self.subset, self.edge_index_adj, self.mapping, self.edge_mask_hard = k_hop_subgraph(
            self.node_idx, k, self.edge_index, relabel_nodes=True)
        self.x_adj = self.x[self.subset]
        self.device = device

        with torch.no_grad():
            # NOTE(review): the *full* `self.x` is paired with the
            # *relabelled* subgraph edge_index — looks like it should be
            # `self.x_adj`; confirm against the caller.
            self.preds = model(self.x, self.edge_index_adj)

        # Number of edges in the subgraph = flow dimensionality.
        self.N = self.edge_index_adj.size(1)
        self.base_dist = dist.Normal(torch.zeros(self.N).to(device), torch.ones(self.N).to(device))
        self.splines = []
        for i in range(splines):
            self.splines.append(T.spline(self.N).to(device))
        self.flow_dist = dist.TransformedDistribution(self.base_dist,self.splines)

        self.sigmoid = sigmoid
def test_heterogeneous_neighbor_loader_on_cora(directed):
    """Check that :class:`NeighborLoader` on a single-node-type
    :class:`HeteroData` copy of Cora samples the exact 2-hop neighbourhood
    (validated against ``k_hop_subgraph``) and that a ``to_hetero`` model
    reproduces the full-graph output on the seed nodes."""
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, 'Cora')
    data = dataset[0]
    data.edge_weight = torch.rand(data.num_edges)

    # Wrap the homogeneous graph as a one-type heterogeneous graph.
    hetero_data = HeteroData()
    hetero_data['paper'].x = data.x
    hetero_data['paper'].n_id = torch.arange(data.num_nodes)
    hetero_data['paper', 'paper'].edge_index = data.edge_index
    hetero_data['paper', 'paper'].edge_weight = data.edge_weight

    split_idx = torch.arange(5, 8)

    # `num_neighbors=[-1, -1]` samples *all* neighbors for two hops.
    loader = NeighborLoader(hetero_data,
                            num_neighbors=[-1, -1],
                            batch_size=split_idx.numel(),
                            input_nodes=('paper', split_idx),
                            directed=directed)
    assert len(loader) == 1

    hetero_batch = next(iter(loader))
    batch_size = hetero_batch['paper'].batch_size

    if not directed:
        # The undirected batch must contain exactly the 2-hop subgraph.
        n_id, _, _, e_mask = k_hop_subgraph(split_idx,
                                            num_hops=2,
                                            edge_index=data.edge_index,
                                            num_nodes=data.num_nodes)

        n_id = n_id.sort()[0]
        assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist()
        assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum())

    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            return x

    model = GNN(dataset.num_features, 16, dataset.num_classes)
    hetero_model = to_hetero(model, hetero_data.metadata())

    # Seed-node outputs must agree between full-graph and sampled batch.
    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
    out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,
                        hetero_batch.edge_weight_dict)['paper'][:batch_size]
    assert torch.allclose(out1, out2, atol=1e-6)

    try:
        shutil.rmtree(root)
    except PermissionError:
        pass
Пример #10
0
    def ego_subgraph(self):
        """Return the K-hop ego network around ``self.target_node``: the
        node subset and the de-duplicated (row < col) edge list."""
        coo = np.asarray(self.ori_adj.nonzero())
        edge_index = torch.as_tensor(coo,
                                     dtype=torch.long,
                                     device=self.device)
        sub_nodes, sub_edges, *_ = k_hop_subgraph(int(self.target_node),
                                                  self.K, edge_index)
        # Keep each undirected edge exactly once.
        keep = sub_edges[0] < sub_edges[1]
        return sub_nodes, sub_edges[:, keep]
Пример #11
0
def test_k_hop_subgraph():
    """Sanity-check ``k_hop_subgraph`` on a small layered graph.

    NOTE(review): this variant unpacks a 3-tuple (no ``mapping``) and
    expects the seed node *first* in ``subset`` — an older/modified API;
    confirm against the installed version of the function.
    """
    edge_index = torch.tensor([[0, 1, 2, 3, 4, 5], [2, 2, 4, 4, 6, 6]])

    subset, edge_index, edge_mask = k_hop_subgraph(6,
                                                   2,
                                                   edge_index,
                                                   relabel_nodes=True)
    assert subset.tolist() == [6, 2, 3, 4, 5]
    assert edge_index.tolist() == [[1, 2, 3, 4], [3, 3, 0, 0]]
    assert edge_mask.tolist() == [False, False, True, True, True, True]
Пример #12
0
def extract_subgraph(node_idx, num_hops, edge_index):
    """Return ``(edge_index, node_idx)`` for the subgraph around ``node_idx``.

    For ``num_hops == 0`` only the edges directly incident to ``node_idx``
    are kept; otherwise the ``num_hops``-hop subgraph computed by
    :func:`k_hop_subgraph` is returned.
    """
    if num_hops == 0:
        # Vectorized replacement for the original per-column Python loop
        # (which built a list of 0-dim bool tensors via `and`).
        mask = (edge_index[0] == node_idx) | (edge_index[1] == node_idx)
        return edge_index[:, mask], node_idx
    _, new_edge_index, _, _ = k_hop_subgraph(node_idx, num_hops, edge_index)
    return new_edge_index, node_idx
Пример #13
0
def test_hgt_loader_on_cora(get_dataset):
    """Check that :class:`HGTLoader`, when sampling every node for two
    layers, recovers the exact 2-hop neighbourhood (validated against
    ``k_hop_subgraph``) and that a ``to_hetero`` model reproduces the
    full-graph output on the seed nodes."""
    dataset = get_dataset(name='Cora')
    data = dataset[0]
    data.edge_weight = torch.rand(data.num_edges)

    # Wrap the homogeneous graph as a one-type heterogeneous graph.
    hetero_data = HeteroData()
    hetero_data['paper'].x = data.x
    hetero_data['paper'].n_id = torch.arange(data.num_nodes)
    hetero_data['paper', 'paper'].edge_index = data.edge_index
    hetero_data['paper', 'paper'].edge_weight = data.edge_weight

    split_idx = torch.arange(5, 8)

    # Sample the complete two-hop neighborhood:
    loader = HGTLoader(hetero_data,
                       num_samples=[data.num_nodes] * 2,
                       batch_size=split_idx.numel(),
                       input_nodes=('paper', split_idx))
    assert len(loader) == 1

    hetero_batch = next(iter(loader))
    batch_size = hetero_batch['paper'].batch_size

    n_id, _, _, e_mask = k_hop_subgraph(split_idx,
                                        num_hops=2,
                                        edge_index=data.edge_index,
                                        num_nodes=data.num_nodes)

    n_id = n_id.sort()[0]
    assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist()
    assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum())

    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            return x

    model = GNN(dataset.num_features, 16, dataset.num_classes)
    hetero_model = to_hetero(model, hetero_data.metadata())

    # Seed-node outputs must agree between full-graph and sampled batch.
    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
    out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict,
                        hetero_batch.edge_weight_dict)['paper'][:batch_size]
    assert torch.allclose(out1, out2, atol=1e-6)
Пример #14
0
    def __subgraph__(self, node_idx, x, edge_index, **kwargs):
        """Restrict ``x``/``edge_index`` (and any node-/edge-aligned tensor
        kwargs) to the k-hop subgraph around ``node_idx``."""
        num_nodes, num_edges = x.size(0), edge_index.size(1)

        subset, edge_index, mapping, edge_mask = k_hop_subgraph(
            node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
            num_nodes=num_nodes, flow=self.__flow__())

        x = x[subset]
        # Tensors with a node-sized leading dim are subset; edge-sized ones
        # are masked; everything else passes through unchanged.
        for key, item in kwargs.items():
            if torch.is_tensor(item) and item.size(0) == num_nodes:
                item = item[subset]
            elif torch.is_tensor(item) and item.size(0) == num_edges:
                item = item[edge_mask]
            kwargs[key] = item

        return x, edge_index, mapping, edge_mask, kwargs
def test_homogeneous_neighbor_loader_on_cora(directed):
    """Check that :class:`NeighborLoader` on Cora samples the exact 2-hop
    neighbourhood (validated against ``k_hop_subgraph``) and that a 2-layer
    GNN gives identical seed-node outputs on the sampled batch."""
    root = osp.join('/', 'tmp', str(random.randrange(sys.maxsize)))
    dataset = Planetoid(root, 'Cora')
    data = dataset[0]
    data.n_id = torch.arange(data.num_nodes)
    data.edge_weight = torch.rand(data.num_edges)

    split_idx = torch.arange(5, 8)

    # `num_neighbors=[-1, -1]` samples *all* neighbors for two hops.
    loader = NeighborLoader(data,
                            num_neighbors=[-1, -1],
                            batch_size=split_idx.numel(),
                            input_nodes=split_idx,
                            directed=directed)
    assert len(loader) == 1

    batch = next(iter(loader))
    batch_size = batch.batch_size

    if not directed:
        # The undirected batch must contain exactly the 2-hop subgraph.
        n_id, _, _, e_mask = k_hop_subgraph(split_idx,
                                            num_hops=2,
                                            edge_index=data.edge_index,
                                            num_nodes=data.num_nodes)

        assert n_id.sort()[0].tolist() == batch.n_id.sort()[0].tolist()
        assert batch.num_edges == int(e_mask.sum())

    class GNN(torch.nn.Module):
        def __init__(self, in_channels, hidden_channels, out_channels):
            super().__init__()
            self.conv1 = GraphConv(in_channels, hidden_channels)
            self.conv2 = GraphConv(hidden_channels, out_channels)

        def forward(self, x, edge_index, edge_weight):
            x = self.conv1(x, edge_index, edge_weight).relu()
            x = self.conv2(x, edge_index, edge_weight).relu()
            return x

    model = GNN(dataset.num_features, 16, dataset.num_classes)

    # Seed-node outputs must agree between full-graph and sampled batch.
    out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx]
    out2 = model(batch.x, batch.edge_index, batch.edge_weight)[:batch_size]
    assert torch.allclose(out1, out2, atol=1e-6)

    shutil.rmtree(root)
Пример #16
0
def calcLIDWithoutGrad(x, edge_index, khop: int, knn: int):
    """Estimate the mean Local Intrinsic Dimensionality of node features
    ``x`` without autograd, using each node's k-hop graph neighbourhood as
    the candidate set for its k nearest neighbours."""
    n_nodes = x.shape[0]
    ic(n_nodes)
    # idxs = torch.arange(0, n_nodes).long()
    total_lid = torch.tensor(0.).to(x)

    for idx, xi in tqdm(enumerate(x)):
        khop_nidxs, _, mapping, _ = k_hop_subgraph(idx, num_hops=khop, edge_index=edge_index)
        khop_feat = x[khop_nidxs]

        # Not have to be differentiable! use tg.knn() method
        # Ask for k+1 neighbours, then drop the first (the node itself).
        knn_idx = tg.nn.knn(khop_feat, xi.view(1, -1), k=knn + 1)[1][1:]
        # ic(knn_idx)
        knn_feat = khop_feat[knn_idx] # => [N]
        knn_dist = (knn_feat - xi).norm(dim=-1)
        # ic(knn_dist)
        # MLE estimator: LID = -(1/k * sum_i log(d_i / d_max))^-1.
        lid = (knn_dist / torch.max(knn_dist)).log().sum() / knn
        lid = - lid ** -1.
        # ic(lid)
        total_lid += lid
    return total_lid / n_nodes
    def __init__(self,
                 model,
                 node_idx: int,
                 k: int,
                 x,
                 edge_index,
                 sharp: float = 0.01):
        """Cache the k-hop subgraph around ``node_idx`` and the model's
        reference predictions for later explanation."""
        device = torch.device('cpu')
        self.model = model.to(device)
        self.x = x.to(device)
        self.edge_index = edge_index.to(device)
        self.node_idx = node_idx
        self.k = k
        self.sharp = sharp
        # Relabelled k-hop subgraph around the node being explained.
        self.subset, self.edge_index_adj, self.mapping, self.edge_mask_hard = k_hop_subgraph(
            self.node_idx, k, self.edge_index, relabel_nodes=True)
        self.x_adj = self.x[self.subset]
        self.device = device

        with torch.no_grad():
            # NOTE(review): the *full* `self.x` is paired with the
            # *relabelled* subgraph edge_index — looks like it should be
            # `self.x_adj`; confirm against the caller.
            self.preds = model(self.x, self.edge_index_adj)

        # Number of edges in the subgraph.
        self.N = self.edge_index_adj.size(1)
Пример #18
0
# Trade memory consumption for faster computation.
if args.dataset in ['AIFB', 'MUTAG']:
    RGCNConv = FastRGCNConv

path = osp.join(osp.dirname(osp.realpath(__file__)), '..', 'data', 'Entities')
dataset = Entities(path, args.dataset)
data = dataset[0]

# BGS and AM graphs are too big to process them in a full-batch fashion.
# Since our model does only make use of a rather small receptive field, we
# filter the graph to only contain the nodes that are at most 2-hop neighbors
# away from any training/test node.
node_idx = torch.cat([data.train_idx, data.test_idx], dim=0)
node_idx, edge_index, mapping, edge_mask = k_hop_subgraph(node_idx,
                                                          2,
                                                          data.edge_index,
                                                          relabel_nodes=True)

data.num_nodes = node_idx.size(0)
data.edge_index = edge_index
# Keep only the edge types of surviving edges.
data.edge_type = data.edge_type[edge_mask]
# `mapping` lists the relabelled train indices first, then the test
# indices, matching the concatenation order above.
data.train_idx = mapping[:data.train_idx.size(0)]
data.test_idx = mapping[data.train_idx.size(0):]

class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = RGCNConv(data.num_nodes,
                              16,
                              dataset.num_relations,
Пример #19
0
def visualize_subgraph2(explainer, node_idx, edge_index, edge_mask, y=None, threshold=None, only_topk_edges=None, **kwargs):
    r"""Visualizes the subgraph around :attr:`node_idx` given an edge mask
    :attr:`edge_mask`.

    Args:
        node_idx (int): The node id to explain.
        edge_index (LongTensor): The edge indices.
        edge_mask (Tensor): The edge mask.
        y (Tensor, optional): The ground-truth node-prediction labels used
            as node colorings. (default: :obj:`None`)
        threshold (float, optional): Sets a threshold for visualizing
            important edges. If set to :obj:`None`, will visualize all
            edges with transparancy indicating the importance of edges.
            (default: :obj:`None`)
        only_topk_edges (int, optional): If set, only the edges with the
            :obj:`only_topk_edges` highest mask values are drawn and
            :obj:`threshold` is ignored. (default: :obj:`None`)
        **kwargs (optional): Additional arguments passed to
            :func:`nx.draw`.

    :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
    """

    assert edge_mask.size(0) == edge_index.size(1)

    #Only operate on a k-hop subgraph around `node_idx`.
    subset, edge_index2, _, hard_edge_mask = k_hop_subgraph(
        node_idx, explainer.__num_hops__(), edge_index, relabel_nodes=True,
        num_nodes=None, flow=explainer.__flow__())

    #edge_mask = edge_mask[hard_edge_mask]
    # NOTE(review): the relabelled `edge_index2` from k_hop_subgraph is
    # immediately overwritten with the *unrelabelled* masked edges here,
    # so relabel_nodes=True above has no effect on it — confirm intended.
    edge_index2 = edge_index[:, hard_edge_mask]
    edge_tuples = set([(a, b) for a, b in zip(edge_index2.numpy()[0], edge_index2.numpy()[1])])
    #edge_index = edge_index[:, edge_mask.argsort()[-only_topk_edges:]]
    #print('subset', subset)
    #print('edge_mask', edge_mask)
    if only_topk_edges:
        #top_k_nodes = set(edge_index[:, edge_mask.argsort()[-only_topk_edges:]].tolist()[0])
        # Indices of the k edges with the largest mask values.
        visible_edges = edge_mask.argsort()[-only_topk_edges:]
        to_be_deleted = []
        for i, v in enumerate(visible_edges):
            # Drop top-k edges outside the k-hop subgraph and self-loops.
            if (edge_index.numpy()[0, v], edge_index.numpy()[1, v]) not in edge_tuples:

                to_be_deleted.append(i)
            elif edge_index[0, v] == edge_index[1, v]:
                to_be_deleted.append(i)

        #visible_edges = np.delete(visible_edges, to_be_deleted) 

        #print(visible_edges)
        # NOTE(review): `to_be_deleted` is collected but the deletion above
        # is commented out, so filtered edges are still drawn — confirm.
        edge_mask = edge_mask[visible_edges]
        edge_index = edge_index[:, visible_edges]
        #print(edge_index)
        chosen_nodes = np.unique(edge_index.flatten())

        subset = chosen_nodes
    else:
        if threshold is not None:
            if threshold < 1:
                # Fractional threshold: binarize the mask.
                edge_mask = (edge_mask >= threshold).to(torch.float)
            else:
                #top N
                edge_mask = np.where(edge_mask >= np.sort(edge_mask)[-threshold], 1, 0)

    if y is None:
        y = torch.zeros(subset.shape[0],
                        device=edge_index.device)
    else:
        # Normalize labels to [0, 1] for use as node colors.
        y = y[subset].to(torch.float) / y.max().item()

    newdata = Data(edge_index=edge_index, att=edge_mask, y=y,
                num_nodes=y.size(0)).to('cpu')
    G = to_networkx(newdata, node_attrs=None, edge_attrs=['att'],  to_undirected=True, remove_self_loops=True)
    '''
    #should be done before relabelling
    if only_topk_edges:
        top_k_nodes = set(edge_index[:, edge_mask.argsort()[-only_topk_edges:]].tolist()[0])
        #iter over G and delete anyting not in top_k nodes
        G.remove_nodes_from(G.nodes() - top_k_nodes)
    '''
    print(edge_index)
    # First relabel consecutive ids back to original ids, then to
    # human-readable labels built from `U` and `all_distances` (globals).
    mapping = {k: i for k, i in enumerate(subset.tolist())}
    print(mapping)
    G = nx.relabel_nodes(G, mapping)
    print(G.nodes())
    mapping = {i: U[i].split('_')[1] + '\n' + str(int(all_distances[i])) for i in subset.tolist()}
    G = nx.relabel_nodes(G, mapping)
    print(G.nodes())

    kwargs['with_labels'] = kwargs.get('with_labels') or True
    kwargs['font_size'] = kwargs.get('font_size') or 10
    kwargs['node_size'] = kwargs.get('node_size') or 800
    kwargs['cmap'] = kwargs.get('cmap') or 'cool'


    pos = nx.spring_layout(G)
    #pos = nx.circular_layout(G)
    #pos = nx.spectral_layout(G)
    #pos = nx.shell_layout(G)
    ax = plt.gca()

    # Draw each edge with opacity proportional to its mask value.
    for source, target, data in G.edges(data=True):
        ax.annotate(
            '', xy=pos[target], xycoords='data', xytext=pos[source],
            textcoords='data', arrowprops=dict(
                arrowstyle="-",
                alpha=max(data['att'], 0.1),
                shrinkA=sqrt(kwargs['node_size']) / 2.0,
                shrinkB=sqrt(kwargs['node_size']) / 2.0,
                connectionstyle="arc3,rad=0.1",
            ))
    #pdb.set_trace()
    #label_mapping = {n: n + '-' + userLocation[n] for n in G.nodes()}
    #G = nx.relabel_nodes(G, mapping)
    nx.draw_networkx_nodes(G, pos, node_color=y.tolist(), **kwargs)
    nx.draw_networkx_labels(G, pos, **kwargs)

    return ax, G
    return train_acc, test_acc


# Train the model, then score the explainer's edge mask (ROC AUC) against
# ground-truth edge labels over all nodes flagged for explanation.
for epoch in range(1, 2001):
    loss = train()
    if epoch % 200 == 0:
        train_acc, test_acc = test()
        print(f'Epoch: {epoch:04d}, Loss: {loss:.4f}, '
              f'Train: {train_acc:.4f}, Test: {test_acc:.4f}')

model.eval()
targets, preds = [], []
expl = GNNExplainer(model, epochs=300, return_type='raw', log=False)

# Explanation ROC AUC over all test nodes:
self_loop_mask = data.edge_index[0] != data.edge_index[1]
for node_idx in tqdm(data.expl_mask.nonzero(as_tuple=False).view(-1).tolist()):
    _, expl_edge_mask = expl.explain_node(node_idx,
                                          data.x,
                                          data.edge_index,
                                          edge_weight=data.edge_weight)
    # subgraph[3] is the hard edge mask of the 3-hop neighbourhood; only
    # edges inside it (and that are not self-loops) are scored.
    subgraph = k_hop_subgraph(node_idx, num_hops=3, edge_index=data.edge_index)
    expl_edge_mask = expl_edge_mask[self_loop_mask]
    subgraph_edge_mask = subgraph[3][self_loop_mask]
    targets.append(data.edge_label[subgraph_edge_mask].cpu())
    preds.append(expl_edge_mask[subgraph_edge_mask].cpu())

auc = roc_auc_score(torch.cat(targets), torch.cat(preds))
print(f'Mean ROC AUC: {auc:.4f}')
Пример #21
0
    def explain(self, node_idx, target, num_samples=100, top_node=None, p_threshold=0.05, pred_threshold=0.1):
        """PGM-Explainer-style explanation of ``node_idx``.

        Randomly perturbs features of the node's k-hop neighbours, records
        which perturbations noticeably drop the prediction for ``target``,
        and returns a chi-square p-value per neighbour (smaller = more
        influential on the prediction).
        """
        # Candidate set: all nodes within `num_layers` hops (the receptive
        # field), plus the node itself.
        neighbors, _, _, _ = k_hop_subgraph(node_idx, self.num_layers, self.edge_index)
        neighbors = neighbors.cpu().detach().numpy()

        if (node_idx not in neighbors):
            neighbors = np.append(neighbors, node_idx)

        # Reference (unperturbed) softmax predictions for every node.
        pred_torch = self.model(self.X, self.edge_index).cpu()
        soft_pred = np.asarray([softmax(np.asarray(pred_torch[node_].data)) for node_ in range(self.X.shape[0])])

        pred_node = np.asarray(pred_torch[node_idx].data)
        label_node = np.argmax(pred_node)
        soft_pred_node = softmax(pred_node)

        Samples = []
        Pred_Samples = []

        for iteration in range(num_samples):

            # Perturb each neighbour's features with probability 1/2;
            # `sample` records which neighbours were perturbed this round.
            X_perturb = self.X.cpu().detach().numpy()
            sample = []
            for node in neighbors:
                seed = np.random.randint(2)
                if seed == 1:
                    latent = 1
                    X_perturb = self.perturb_features_on_node(X_perturb, node, random=seed)
                else:
                    latent = 0
                sample.append(latent)

            X_perturb_torch = torch.tensor(X_perturb, dtype=torch.float).to(device)
            pred_perturb_torch = self.model(X_perturb_torch, self.edge_index).cpu()
            soft_pred_perturb = np.asarray(
                [softmax(np.asarray(pred_perturb_torch[node_].data)) for node_ in range(self.X.shape[0])])

            # 1 if the perturbation dropped the target-class confidence by
            # more than `pred_threshold`, else 0.
            sample_bool = []
            for node in neighbors:
                if (soft_pred_perturb[node, target] + pred_threshold) < soft_pred[node, target]:
                    sample_bool.append(1)
                else:
                    sample_bool.append(0)

            Samples.append(sample)
            Pred_Samples.append(sample_bool)

        Samples = np.asarray(Samples)
        Pred_Samples = np.asarray(Pred_Samples)
        # Encode (perturbed?, dropped?) pairs into one categorical value
        # per neighbour: 10 * perturbed + dropped + 1.
        Combine_Samples = Samples - Samples
        for s in range(Samples.shape[0]):
            Combine_Samples[s] = np.asarray(
                [Samples[s, i] * 10 + Pred_Samples[s, i] + 1 for i in range(Samples.shape[1])])

        data = pd.DataFrame(Combine_Samples)
        data = data.rename(columns={0: "A", 1: "B"})  # Trick to use chi_square test on first two data columns
        ind_ori_to_sub = dict(zip(neighbors, list(data.columns)))

        # Chi-square independence test between each neighbour's column and
        # the explained node's column.
        p_values = []
        for node in neighbors:
            chi2, p = chi_square(ind_ori_to_sub[node], ind_ori_to_sub[node_idx], [], data)
            p_values.append(p)

        pgm_stats = dict(zip(neighbors, p_values))

        return pgm_stats
Пример #22
0
    def visualize_subgraph(self,
                           node_idx: Optional[int],
                           edge_index: Tensor,
                           edge_mask: Tensor,
                           y: Optional[Tensor] = None,
                           threshold: Optional[int] = None,
                           edge_y: Optional[Tensor] = None,
                           node_alpha: Optional[Tensor] = None,
                           seed: int = 10,
                           **kwargs):
        r"""Visualizes the subgraph given an edge mask :attr:`edge_mask`.

        Args:
            node_idx (int): The node id to explain.
                Set to :obj:`None` to explain a graph.
            edge_index (LongTensor): The edge indices.
            edge_mask (Tensor): The edge mask.
            y (Tensor, optional): The ground-truth node-prediction labels used
                as node colorings. All nodes will have the same color
                if :attr:`node_idx` is :obj:`-1`.(default: :obj:`None`).
            threshold (float, optional): Sets a threshold for visualizing
                important edges. If set to :obj:`None`, will visualize all
                edges with transparancy indicating the importance of edges.
                (default: :obj:`None`)
            edge_y (Tensor, optional): The edge labels used as edge colorings.
            node_alpha (Tensor, optional): Tensor of floats (0 - 1) indicating
                transparency of each node.
            seed (int, optional): Random seed of the :obj:`networkx` node
                placement algorithm. (default: :obj:`10`)
            **kwargs (optional): Additional arguments passed to
                :func:`nx.draw`.

        :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
        """
        import matplotlib.pyplot as plt
        import networkx as nx

        assert edge_mask.size(0) == edge_index.size(1)

        if node_idx is None or node_idx < 0:
            hard_edge_mask = torch.BoolTensor([True] * edge_index.size(1),
                                              device=edge_mask.device)
            subset = torch.arange(edge_index.max().item() + 1,
                                  device=edge_index.device)
            y = None

        else:
            # Only operate on a k-hop subgraph around `node_idx`.
            subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
                node_idx,
                self.num_hops,
                edge_index,
                relabel_nodes=True,
                num_nodes=None,
                flow=self._flow())

        edge_mask = edge_mask[hard_edge_mask]

        if threshold is not None:
            edge_mask = (edge_mask >= threshold).to(torch.float)

        if y is None:
            y = torch.zeros(edge_index.max().item() + 1,
                            device=edge_index.device)
        else:
            y = y[subset].to(torch.float) / y.max().item()

        if edge_y is None:
            edge_color = ['black'] * edge_index.size(1)
        else:
            colors = list(plt.rcParams['axes.prop_cycle'])
            edge_color = [
                colors[i % len(colors)]['color']
                for i in edge_y[hard_edge_mask]
            ]

        data = Data(edge_index=edge_index,
                    att=edge_mask,
                    edge_color=edge_color,
                    y=y,
                    num_nodes=y.size(0)).to('cpu')
        G = to_networkx(data,
                        node_attrs=['y'],
                        edge_attrs=['att', 'edge_color'])
        mapping = {k: i for k, i in enumerate(subset.tolist())}
        G = nx.relabel_nodes(G, mapping)

        node_args = set(signature(nx.draw_networkx_nodes).parameters.keys())
        node_kwargs = {k: v for k, v in kwargs.items() if k in node_args}
        node_kwargs['node_size'] = kwargs.get('node_size') or 800
        node_kwargs['cmap'] = kwargs.get('cmap') or 'cool'

        label_args = set(signature(nx.draw_networkx_labels).parameters.keys())
        label_kwargs = {k: v for k, v in kwargs.items() if k in label_args}
        label_kwargs['font_size'] = kwargs.get('font_size') or 10

        pos = nx.spring_layout(G, seed=seed)
        ax = plt.gca()
        for source, target, data in G.edges(data=True):
            ax.annotate('',
                        xy=pos[target],
                        xycoords='data',
                        xytext=pos[source],
                        textcoords='data',
                        arrowprops=dict(
                            arrowstyle="->",
                            alpha=max(data['att'], 0.1),
                            color=data['edge_color'],
                            shrinkA=sqrt(node_kwargs['node_size']) / 2.0,
                            shrinkB=sqrt(node_kwargs['node_size']) / 2.0,
                            connectionstyle="arc3,rad=0.1",
                        ))

        if node_alpha is None:
            nx.draw_networkx_nodes(G,
                                   pos,
                                   node_color=y.tolist(),
                                   **node_kwargs)
        else:
            node_alpha_subset = node_alpha[subset]
            assert ((node_alpha_subset >= 0) & (node_alpha_subset <= 1)).all()
            nx.draw_networkx_nodes(G,
                                   pos,
                                   alpha=node_alpha_subset.tolist(),
                                   node_color=y.tolist(),
                                   **node_kwargs)

        nx.draw_networkx_labels(G, pos, **label_kwargs)

        return ax, G
Пример #23
0
    def visualize_subgraph(self,
                           node_idx,
                           dataset,
                           edge_mask,
                           pos=None,
                           y=None,
                           show=True,
                           save=False,
                           verbose=True,
                           threshold=None,
                           **kwargs):
        """Visualize the explanation subgraph around ``node_idx``.

        Args:
            node_idx (int or None): Node to explain. If ``None``, no k-hop
                extraction is done and the node set is collected from the
                endpoints of ``dataset.edge_index``.
            dataset: Object exposing ``edge_index``, ``y`` and ``X``
                attributes (labels and node markers are printed in verbose
                mode).
            edge_mask (Tensor): Importance score per edge; must align with
                ``dataset.edge_index``.
            pos (dict, optional): Precomputed networkx node layout.
            y (Tensor, optional): Node labels used as node colors.
            show (bool): Whether to draw the plot with matplotlib.
            save (bool): Save the figure to ``plot/sample`` before showing.
            verbose (bool): Print the nodes/edges of the extracted subgraph.
            threshold (float, optional): Binarize ``edge_mask`` at this value.
            **kwargs (optional): Extra arguments forwarded to the networkx
                drawing helpers.

        Returns:
            networkx.Graph: The relabeled explanation subgraph.
        """
        edge_index = dataset.edge_index

        assert edge_mask.size(0) == edge_index.size(1)

        if threshold is not None:
            print('Edge Threshold:', threshold)
            # Fix: ``torch.tensor(tensor, ...)`` warns about copying an
            # existing tensor; cast the boolean mask instead.
            edge_mask = (edge_mask >= threshold).to(torch.float)

        if node_idx is not None:
            # Only operate on a k-hop subgraph around `node_idx`.
            subset, edge_index, _, hard_edge_mask = k_hop_subgraph(
                node_idx,
                self.num_hops,
                edge_index,
                relabel_nodes=True,
                num_nodes=None,
                flow=self.__flow__())
            edge_mask = edge_mask[hard_edge_mask]

        else:
            # Collect every endpoint of the (unfiltered) edge set in
            # first-seen order.
            seen = []
            for index in range(edge_index.size(1)):
                node_a = edge_index[0, index].item()
                node_b = edge_index[1, index].item()
                if node_a not in seen:
                    seen.append(node_a)
                if node_b not in seen:
                    seen.append(node_b)
            # Fix: downstream code calls ``subset.tolist()`` and indexes ``y``
            # with it, so ``subset`` must be a tensor in both branches (the
            # old code left a plain list here and crashed on ``.tolist()``).
            subset = torch.tensor(seen, dtype=torch.long,
                                  device=edge_index.device)
            # NOTE(review): in this branch ``edge_index`` is NOT relabeled, so
            # the enumerate-based relabeling below is only correct when node
            # ids are already contiguous from 0 — confirm against callers.

        if y is None:
            y = torch.zeros(edge_index.max().item() + 1,
                            device=edge_index.device)
        else:
            # Normalize labels into [0, 1] for use as a color scale.
            y = y[subset].to(torch.float) / y.max().item()

        data = Data(edge_index=edge_index,
                    att=edge_mask,
                    y=y,
                    num_nodes=y.size(0)).to('cpu')

        G = to_networkx(data, node_attrs=['y'], edge_attrs=['att'])
        # Map consecutive networkx node ids back to the original node ids.
        mapping = {k: i for k, i in enumerate(subset.tolist())}
        G = nx.relabel_nodes(G, mapping)

        if verbose:
            print("Node: ", node_idx, "; Label:", dataset.y[node_idx].item())
            print("Related nodes:", G.nodes)
            print("Related edges:")
            for source, target, graph_data in G.edges(data=True):
                if graph_data['att'] > 0.95:
                    print('{}->{}: {}'.format(source, target,
                                              graph_data['att']))
            for node in G.nodes:
                print("Node:", node, "; Label:", dataset.y[node].item(),
                      "; Marker:", dataset.X[node])

        if show:
            # Fix: ``kwargs.get(key) or default`` overrode an explicit falsy
            # value (e.g. ``with_labels=False``); ``get(key, default)``
            # respects the caller's choice.
            kwargs['with_labels'] = kwargs.get('with_labels', True)
            kwargs['font_size'] = kwargs.get('font_size', 10)
            kwargs['node_size'] = kwargs.get('node_size', 200)
            kwargs['cmap'] = kwargs.get('cmap', 'Set3')

            if pos is None:
                pos = nx.spring_layout(G)
            ax = plt.gca()

            for source, target, graph_data in G.edges(data=True):
                ax.annotate('',
                            xy=pos[target],
                            xycoords='data',
                            xytext=pos[source],
                            textcoords='data',
                            arrowprops=dict(
                                arrowstyle="->",
                                # Keep faint edges minimally visible.
                                alpha=max(graph_data['att'], 0.1),
                                shrinkA=sqrt(1000) / 2.0,
                                shrinkB=sqrt(1000) / 2.0,
                                connectionstyle="arc3,rad=0.1",
                            ))

            nx.draw_networkx_nodes(G, pos, node_color=y.flatten(), **kwargs)
            nx.draw_networkx_labels(G, pos, **kwargs)

            if save:
                plt.savefig('plot/sample')
            plt.show()

        return G
Пример #24
0
def extract_subgraph(node_idx, num_hops, edge_index):
    """Return the edge index of the ``num_hops``-hop subgraph around
    ``node_idx`` (without relabeling), together with ``node_idx`` itself."""
    _, sub_edge_index, _, _ = k_hop_subgraph(node_idx, num_hops, edge_index)
    return sub_edge_index, node_idx
Пример #25
0
    def visualize_subgraph(self, node_idx, edge_index, edge_mask, y=None,
                           threshold=None, **kwargs):
        r"""Draws the explanation subgraph around :attr:`node_idx`.

        Args:
            node_idx (int): The node id to explain.
            edge_index (LongTensor): The edge indices.
            edge_mask (Tensor): The edge mask.
            y (Tensor, optional): The ground-truth node-prediction labels used
                as node colorings. (default: :obj:`None`)
            threshold (float, optional): Sets a threshold for visualizing
                important edges. If set to :obj:`None`, will visualize all
                edges with transparancy indicating the importance of edges.
                (default: :obj:`None`)
            **kwargs (optional): Additional arguments passed to
                :func:`nx.draw`.

        :rtype: :class:`matplotlib.axes.Axes`, :class:`networkx.DiGraph`
        """

        assert edge_mask.size(0) == edge_index.size(1)

        # Restrict the view to the k-hop neighborhood of ``node_idx`` and
        # keep only the mask entries of the surviving edges.
        subset, edge_index, _, kept_edges = k_hop_subgraph(
            node_idx, self.__num_hops__(), edge_index, relabel_nodes=True,
            num_nodes=None, flow=self.__flow__())

        edge_mask = edge_mask[kept_edges]
        if threshold is not None:
            edge_mask = (edge_mask >= threshold).to(torch.float)

        if y is not None:
            # Normalize labels into [0, 1] for use as a color scale.
            y = y[subset].to(torch.float) / y.max().item()
        else:
            y = torch.zeros(edge_index.max().item() + 1,
                            device=edge_index.device)

        graph_data = Data(edge_index=edge_index, att=edge_mask, y=y,
                          num_nodes=y.size(0)).to('cpu')
        G = to_networkx(graph_data, node_attrs=['y'], edge_attrs=['att'])
        # Map consecutive networkx ids back to the original node ids.
        G = nx.relabel_nodes(G, dict(enumerate(subset.tolist())))

        kwargs['with_labels'] = False
        kwargs['font_size'] = kwargs.get('font_size') or 5
        kwargs['node_size'] = kwargs.get('node_size') or 800
        kwargs['cmap'] = kwargs.get('cmap') or 'RdYlGn'

        pos = nx.spring_layout(G)
        ax = plt.gca()
        arrow_shrink = sqrt(kwargs['node_size']) / 2.0
        for src, dst, attrs in G.edges(data=True):
            ax.annotate(
                '', xy=pos[dst], xycoords='data', xytext=pos[src],
                textcoords='data', arrowprops=dict(
                    arrowstyle="->",
                    # Keep faint edges minimally visible.
                    alpha=max(attrs['att'], 0.1),
                    shrinkA=arrow_shrink,
                    shrinkB=arrow_shrink,
                    connectionstyle="arc3,rad=0.1",
                ))
        nx.draw_networkx_nodes(G, pos, node_color=y.tolist(), alpha=0.5,
                               **kwargs)
        nx.draw_networkx_labels(G, pos, **kwargs)

        return ax, G
Пример #26
0
def get_data_sample(G,
                    set_index,
                    hop_num,
                    feature_flags,
                    max_sprw,
                    label,
                    debug=False):
    """Build one ``Data`` sample for the node set ``set_index`` of graph ``G``.

    Extracts the ``hop_num``-hop subgraph around the set, rebuilds it as a
    networkx graph, and assembles node features from the graph's stored
    attributes plus optional shortest-path and random-walk encodings.

    Args:
        G (networkx graph): Source graph; ``G.graph['attributes']`` may hold
            per-node attribute rows, or ``None``.
        set_index (iterable of int): Node ids forming the target set.
        hop_num (int): Number of hops for the extracted subgraph.
        feature_flags (tuple): ``(sp_flag, rw_flag)`` — enable shortest-path
            and random-walk features respectively.
        max_sprw (tuple): ``(max_sp, rw_depth)`` feature parameters.
        label (int or None): Sample label; ``None`` falls back to label ``0``.
        debug (bool): If ``True``, additionally attach the pre-relabeling
            set/subgraph indices to the returned sample.

    Returns:
        Data: Sample with ``x``, ``edge_index``, ``y`` and ``set_indices``
        (plus ``old_set_indices``/``old_subgraph_indices`` when ``debug``).
    """
    # first, extract subgraph
    set_index = list(set_index)
    sp_flag, rw_flag = feature_flags
    max_sp, rw_depth = max_sprw
    # For multi-node sets, drop all intra-set edges so the features describe
    # the relation between the set nodes rather than their direct links.
    if len(set_index) > 1:
        G = G.copy()
        G.remove_edges_from(combinations(set_index, 2))
    # Duplicate each edge in the reverse direction to make the edge list
    # symmetric (undirected representation).
    edge_index = torch.tensor(list(G.edges)).long().t().contiguous()
    edge_index = torch.cat([edge_index, edge_index[[1, 0], ]], dim=-1)
    subgraph_node_old_index, new_edge_index, new_set_index, edge_mask = tgu.k_hop_subgraph(
        torch.tensor(set_index).long(),
        hop_num,
        edge_index,
        num_nodes=G.number_of_nodes(),
        relabel_nodes=True)

    # reconstruct networkx graph object for the extracted subgraph
    num_nodes = subgraph_node_old_index.size(0)
    new_G = nx.from_edgelist(new_edge_index.t().numpy().astype(dtype=np.int32),
                             create_using=type(G))
    new_G.add_nodes_from(np.arange(
        num_nodes, dtype=np.int32))  # to add disconnected nodes
    assert (new_G.number_of_nodes() == num_nodes)

    # Construct x from x_list
    x_list = []
    attributes = G.graph['attributes']
    if attributes is not None:
        # Keep only the attribute rows of the retained subgraph nodes.
        new_attributes = torch.tensor(
            attributes, dtype=torch.float32)[subgraph_node_old_index]
        if new_attributes.dim() < 2:
            new_attributes.unsqueeze_(1)
        x_list.append(new_attributes)
    # if deg_flag:
    #     x_list.append(torch.log(tgu.degree(new_edge_index[0], num_nodes=num_nodes, dtype=torch.float32).unsqueeze(1)+1))
    if sp_flag:
        # Shortest-path-distance encodings relative to the (relabeled) set.
        features_sp_sample = get_features_sp_sample(new_G,
                                                    new_set_index.numpy(),
                                                    max_sp=max_sp)
        features_sp_sample = torch.from_numpy(features_sp_sample).float()
        x_list.append(features_sp_sample)
    if rw_flag:
        # Dense adjacency ordered by the relabeled node ids; used to compute
        # random-walk landing-probability features.
        adj = np.asarray(
            nx.adjacency_matrix(
                new_G,
                nodelist=np.arange(new_G.number_of_nodes(),
                                   dtype=np.int32)).todense().astype(
                                       np.float32))  # [n_nodes, n_nodes]
        features_rw_sample = get_features_rw_sample(adj,
                                                    new_set_index.numpy(),
                                                    rw_depth=rw_depth)
        features_rw_sample = torch.from_numpy(features_rw_sample).float()
        x_list.append(features_rw_sample)

    # NOTE(review): assumes at least one feature source is enabled —
    # ``torch.cat`` raises on an empty ``x_list``.
    x = torch.cat(x_list, dim=-1)
    y = torch.tensor([label],
                     dtype=torch.long) if label is not None else torch.tensor(
                         [0], dtype=torch.long)
    new_set_index = new_set_index.long().unsqueeze(0)
    if not debug:
        return Data(x=x,
                    edge_index=new_edge_index,
                    y=y,
                    set_indices=new_set_index)
    else:
        return Data(
            x=x,
            edge_index=new_edge_index,
            y=y,
            set_indices=new_set_index,
            old_set_indices=torch.tensor(set_index).long().unsqueeze(0),
            old_subgraph_indices=subgraph_node_old_index)
def get_data_sample(G, set_index, edge_index, num_hop, feature_flags, max_sprw,
                    label):
    """Build one ``Data`` sample around ``set_index`` from a precomputed edge index.

    Variant of the subgraph sampler that takes ``edge_index`` directly instead
    of deriving it from ``G``; extracts the ``num_hop``-hop subgraph, rebuilds
    it in networkx, and assembles node features from stored attributes plus
    optional shortest-path and random-walk encodings.

    Args:
        G (networkx graph): Source graph; ``G.graph['attributes']`` may hold
            per-node attribute rows, or ``None``.
        set_index (iterable of int): Node ids forming the target set; length
            1 for node classification.
        edge_index (LongTensor): Edge indices of ``G``.
        num_hop (int): Number of hops for the extracted subgraph.
        feature_flags (tuple): ``(sp_flag, rw_flag)`` feature toggles.
        max_sprw (tuple): ``(max_sp, max_rw)`` feature parameters.
        label (int or None): Sample label; ``None`` falls back to label ``0``.

    Returns:
        Data: Sample with ``x``, ``edge_index``, ``y``, ``set_indices`` and
        the pre-relabeling ``old_set_indices``.
    """
    set_index = list(set_index)
    sp_flag, rw_flag = feature_flags
    max_sp, max_rw = max_sprw

    # extract subgraph from the root node with num_hop; for node classification, len(set_index)=1
    subgraph_node_old_index, new_edge_index, new_set_index, edge_mask = tgu.k_hop_subgraph(
        torch.tensor(set_index).long(),
        num_hop,
        edge_index,
        num_nodes=G.number_of_nodes(),
        relabel_nodes=True)

    # reconstruct networkx graph object for the extracted subgraph
    num_nodes = subgraph_node_old_index.size(0)
    new_G = nx.from_edgelist(new_edge_index.t().numpy().astype(dtype=np.int32),
                             create_using=type(G))
    new_G.add_nodes_from(np.arange(
        num_nodes, dtype=np.int32))  # to add disconnected nodes
    assert (new_G.number_of_nodes() == num_nodes)

    # assemble x from features to x_list
    x_list = []
    attributes = G.graph['attributes']
    if attributes is not None:
        # Keep only the attribute rows of the retained subgraph nodes.
        new_attributes = torch.tensor(
            attributes, dtype=torch.float32)[subgraph_node_old_index]
        if new_attributes.dim() < 2:
            new_attributes.unsqueeze_(1)
        x_list.append(new_attributes)

    if sp_flag:
        # Shortest-path-distance encodings relative to the (relabeled) set.
        features_sp_sample = gen_sp_features(new_G,
                                             new_set_index.numpy(),
                                             max_sp=max_sp)
        features_sp_sample = torch.from_numpy(features_sp_sample).float()
        x_list.append(features_sp_sample)

    if rw_flag:
        # use sparse matrix for computing the landing probabilities [n_nodes, n_nodes]
        adj = nx.adjacency_matrix(new_G,
                                  nodelist=np.arange(new_G.number_of_nodes(),
                                                     dtype=np.int32))
        features_rw_sample = gen_rw_features(adj,
                                             new_set_index.numpy(),
                                             rw_depth=max_rw)
        features_rw_sample = torch.from_numpy(features_rw_sample).float()
        x_list.append(features_rw_sample)

    # NOTE(review): assumes at least one feature source is enabled —
    # ``torch.cat`` raises on an empty ``x_list``.
    x = torch.cat(x_list, dim=-1)
    y = torch.tensor([label],
                     dtype=torch.long) if label is not None else torch.tensor(
                         [0], dtype=torch.long)
    new_set_index = new_set_index.long().unsqueeze(0)

    return Data(x=x,
                edge_index=new_edge_index,
                y=y,
                set_indices=new_set_index,
                old_set_indices=torch.tensor(set_index).long().unsqueeze(0))