# Collected GNN pooling models (PyTorch / PyTorch Geometric). The classes
# below were gathered from separate files, so some names repeat (two ASAP,
# SAGPooling and Graclus variants); later definitions would shadow earlier
# ones if this were imported as a single module. The import block gathers
# the names the snippets reference; `EdgePoolingMod` and
# `batched_negative_edges` are project-specific helpers assumed to exist.
from collections import namedtuple

import numpy as np
import torch
import torch.nn.functional as F
import torch_geometric as tg
from ogb.graphproppred.mol_encoder import AtomEncoder
from torch.nn import Linear, Parameter
from torch.nn.init import uniform_
from torch_geometric.data import Batch, Data
from torch_geometric.nn import (ASAPooling, GATConv, GCNConv, GraphConv,
                                JumpingKnowledge, SAGEConv, TopKPooling,
                                global_add_pool, global_max_pool,
                                global_mean_pool, graclus, max_pool)
from torch_geometric.nn.pool.topk_pool import filter_adj, topk
from torch_geometric.utils import degree, softmax
from torch_scatter import scatter_add
from torch_sparse import coalesce


class GGCN(torch.nn.Module):
    def __init__(self, num_layers=2, hidden=200, features_num=16,
                 num_class=2, dropout=0.5):
        super(GGCN, self).__init__()
        self.conv1 = GraphConv(features_num, hidden, aggr='add')
        self.lin2 = Linear(hidden, num_class)
        self.dropout = dropout
        print("hidden=%d, dropout=%f" % (hidden, self.dropout))

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.relu(self.conv1(x, edge_index, edge_weight=edge_weight))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
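
# Usage sketch (not part of the original source): GGCN does per-node
# classification on a single Data object carrying `x`, `edge_index` and
# `edge_weight`. The toy graph and labels below are invented for
# illustration only.
def _ggcn_demo():
    edge_index = torch.tensor([[0, 1, 1, 2],
                               [1, 0, 2, 1]])          # 3-node path graph
    data = Data(x=torch.randn(3, 16), edge_index=edge_index,
                edge_weight=torch.ones(4))
    model = GGCN(features_num=16, num_class=2)
    log_probs = model(data)                            # [3, 2] per-node log-probs
    return F.nll_loss(log_probs, torch.tensor([0, 1, 0]))
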
class ASAP(torch.nn.Module):
    def __init__(self, num_classes, num_features, num_layers, hidden,
                 ratio=0.8, dropout=0):
        super(ASAP, self).__init__()
        self.conv1 = GraphConv(num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.convs.extend([
            GraphConv(hidden, hidden, aggr='mean')
            for i in range(num_layers - 1)
        ])
        self.pools.extend([
            ASAPooling(hidden, ratio, dropout=dropout)
            for i in range(num_layers // 2)
        ])
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        edge_weight = None
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = F.relu(x)
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                x, edge_index, edge_weight, batch, _ = pool(
                    x=x, edge_index=edge_index, edge_weight=edge_weight,
                    batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
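
# Usage sketch (illustrative, not from the source): this ASAP variant
# classifies whole graphs from a mini-batch. The `model_cls` default is
# bound at definition time, so it refers to the class above even though a
# second ASAP class is defined later in this file.
def _asap_demo(model_cls=ASAP):
    graphs = [Data(x=torch.randn(10, 8),
                   edge_index=torch.randint(0, 10, (2, 30)))
              for _ in range(4)]
    batch = Batch.from_data_list(graphs)
    model = model_cls(num_classes=3, num_features=8, num_layers=3, hidden=32)
    return model(batch)                                # [4, 3] log-probs
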
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # `args` and `dataset` are module-level globals of the surrounding
        # training script this snippet was taken from.
        hidden = args.hidden
        num_layers = 5
        ratio = 0.8
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='add')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.convs.extend([
            GraphConv(hidden, hidden, aggr='add')
            for i in range(num_layers - 1)
        ])
        self.pools.extend(
            [TopKPooling(hidden, ratio) for i in range(num_layers // 2)])
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_add_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_add_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                # TopKPooling here returns six values:
                # (x, edge_index, edge_attr, batch, perm, score)
                x, edge_index, _, batch, _, _ = pool(x, edge_index,
                                                     batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)
class Graclus(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(Graclus, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class TopK(torch.nn.Module):
    def __init__(self, dataset, num_layers, hidden):
        super(TopK, self).__init__()
        self.conv1 = GraphConv(dataset.num_features, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='mean'))
            self.pools.append(TopKPooling(hidden, ratio=0.8))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv, pool in zip(self.convs, self.pools):
            conv.reset_parameters()
            pool.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, (conv, pool) in enumerate(zip(self.convs, self.pools)):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0:
                # Note: this unpacks the five return values of older
                # TopKPooling releases; newer PyG versions also return the
                # node scores as a sixth value (cf. the Net class above).
                x, edge_index, _, batch, _ = pool(x, edge_index, batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class GeoGraph(torch.nn.Module):
    def __init__(self, hidden, geo, training):
        super(GeoGraph, self).__init__()
        self.geo = geo
        # Note: assigning `training` here shadows the flag that
        # torch.nn.Module.train()/.eval() normally toggle.
        self.training = training
        if self.geo:
            self.conv1 = GraphConv(64 + 3, hidden, aggr='mean')
        else:
            self.conv1 = GraphConv(64, hidden, aggr='mean')
        self.conv2 = GraphConv(hidden, 32, aggr='mean')
        self.conv3 = GraphConv(32, 16, aggr='mean')
        self.lin1 = Linear(16, 16)
        self.pool1 = EdgePoolingMod(16)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.conv3.reset_parameters()
        self.lin1.reset_parameters()

    def forward(self, data):
        # Data(edge_index=[2, 210], neg_edge_index=[2, 182],
        #      pos_edge_index=[2, 28], x=[15, 512], x_bbox=[15, 4],
        #      x_heading=[15, 2], x_img_pos=[15, 2], x_pos=[15, 2], y=[210])
        x, edge_index = data.x.cuda(), data.edge_index.cuda()
        reg = data.regressions.cuda()
        if self.training:
            edge_y = data.y.cuda()
        if self.geo:
            # The geometric features must be appended in evaluation mode as
            # well, since conv1 was built for 64 + 3 input channels (the
            # original only concatenated them in the training branch).
            geo = data.geos.cuda()
            # x = torch.cat((x, reg), 1)
            x = torch.cat((x, geo), 1)
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = F.relu(self.conv3(x, edge_index))
        x = x.view(x.size()[0], -1)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x, edge_index, batch, edge_scores = self.pool1(x, edge_index,
                                                       batch=None)
        return edge_scores

    def __repr__(self):
        return self.__class__.__name__
class ASAP(torch.nn.Module):
    def __init__(self, num_vocab, max_seq_len, node_encoder, emb_dim,
                 num_layers, hidden, ratio=0.8, dropout=0, num_class=0):
        super(ASAP, self).__init__()
        self.num_class = num_class
        self.max_seq_len = max_seq_len
        self.node_encoder = node_encoder
        self.conv1 = GraphConv(emb_dim, hidden, aggr='mean')
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        self.convs.extend([
            GraphConv(hidden, hidden, aggr='mean')
            for i in range(num_layers - 1)
        ])
        self.pools.extend([
            ASAPooling(hidden, ratio, dropout=dropout)
            for i in range(num_layers // 2)
        ])
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        # self.lin2 = Linear(hidden, dataset.num_classes)
        if self.num_class > 0:  # classification
            self.graph_pred_linear = torch.nn.Linear(hidden, self.num_class)
        else:
            self.graph_pred_linear_list = torch.nn.ModuleList()
            for i in range(max_seq_len):
                self.graph_pred_linear_list.append(
                    torch.nn.Linear(hidden, num_vocab))

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for pool in self.pools:
            pool.reset_parameters()
        self.lin1.reset_parameters()
        # `lin2` was replaced by the prediction heads, so reset those
        # instead (the original still called self.lin2.reset_parameters()).
        if self.num_class > 0:
            self.graph_pred_linear.reset_parameters()
        else:
            for lin in self.graph_pred_linear_list:
                lin.reset_parameters()

    def forward(self, data):
        x, edge_index, node_depth, batch = (data.x, data.edge_index,
                                            data.node_depth, data.batch)
        x = self.node_encoder(x, node_depth.view(-1, ))
        edge_weight = None
        x = F.relu(self.conv1(x, edge_index))
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = conv(x=x, edge_index=edge_index, edge_weight=edge_weight)
            x = F.relu(x)
            xs += [global_mean_pool(x, batch)]
            if i % 2 == 0 and i < len(self.convs) - 1:
                pool = self.pools[i // 2]
                x, edge_index, edge_weight, batch, _ = pool(
                    x=x, edge_index=edge_index, edge_weight=edge_weight,
                    batch=batch)
        x = self.jump(xs)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=0.5, training=self.training)
        # x = self.lin2(x)
        # return F.log_softmax(x, dim=-1)
        if self.num_class > 0:
            return self.graph_pred_linear(x)
        pred_list = []
        for i in range(self.max_seq_len):
            pred_list.append(self.graph_pred_linear_list[i](x))
        return pred_list

    def __repr__(self):
        return self.__class__.__name__
class SAGPooling(torch.nn.Module):
    r"""The self-attention pooling operator from the `"Self-Attention Graph
    Pooling" <https://arxiv.org/abs/1904.08082>`_ paper

    .. math::
        \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})

        \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

        \mathbf{X}^{\prime} &= (\mathbf{X} \odot
        \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on a learnable projection score
    :math:`\mathbf{p}`. Projection scores are learned by a graph neural
    network layer.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
            (default: :obj:`0.5`)
        gnn (string, optional): Specifies which graph neural network layer
            to use for calculating projection scores (one of :obj:`"GCN"`,
            :obj:`"GAT"`, :obj:`"SAGE"` or :obj:`"gConv"`).
            (default: :obj:`"gConv"`)
        **kwargs (optional): Additional parameters for initializing the
            graph neural network layer.
    """
    def __init__(self, in_channels, ratio=0.5, gnn='gConv', **kwargs):
        super(SAGPooling, self).__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.gnn_name = gnn

        assert gnn in ['GCN', 'GAT', 'SAGE', 'gConv']
        if gnn == 'GCN':
            self.gnn = GCNConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
            self.gnn = GATConv(self.in_channels, 1, **kwargs)
        elif gnn == 'SAGE':
            self.gnn = SAGEConv(self.in_channels, 1, **kwargs)
        else:
            self.gnn = GraphConv(self.in_channels, 1, **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """"""
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x
        score = torch.tanh(self.gnn(x, edge_index).view(-1))
        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                           num_nodes=score.size(0))

        return x, edge_index, edge_attr, batch, perm

    def __repr__(self):
        return '{}({}, {}, ratio={})'.format(self.__class__.__name__,
                                             self.gnn_name, self.in_channels,
                                             self.ratio)
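
# Usage sketch (illustrative only): the pooling layer is called directly on
# node features and connectivity; with ratio=0.5 it keeps ceil(0.5 * N)
# nodes per graph. `pool_cls` is bound at definition time, so it refers to
# the variant above even though a second SAGPooling is defined further below.
def _sag_pool_v1_demo(pool_cls=SAGPooling):
    edge_index = torch.tensor([[0, 1, 1, 2, 3, 4],
                               [1, 0, 2, 1, 4, 3]])
    x = torch.randn(5, 8)
    pool = pool_cls(8, ratio=0.5)
    x, edge_index, edge_attr, batch, perm = pool(x, edge_index)
    return x.size(0)                                   # 3 surviving nodes
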
class StarPooling(torch.nn.Module):
    r"""A star pooling operator adapted from the edge pooling operator of
    the `"Towards Graph Pooling by Edge Contraction"
    <https://graphreason.github.io/papers/17.pdf>`_ and `"Edge Contraction
    Pooling for Graph Neural Networks" <https://arxiv.org/abs/1905.10990>`_
    papers.

    In short, a score is computed for each node. Starting from the
    highest-scoring node, each selected node is contracted together with its
    remaining neighbors into a single cluster (a "star"), unless the node
    has already been absorbed into another star.

    Args:
        in_channels (int): Size of each input sample.
        node_score_method (function, optional): The function to apply to
            compute the node score from raw node scores. This function
            takes in a :obj:`raw_node_score` tensor of shape
            :obj:`[num_nodes]` and produces a new tensor of the same size
            describing normalized node scores. Included functions are
            :func:`StarPooling.compute_node_score_tanh` and
            :func:`StarPooling.compute_node_score_sigmoid`.
            (default: :func:`StarPooling.compute_node_score_tanh`)
        dropout (float, optional): The probability with which to drop
            node scores during training. (default: :obj:`0`)
    """

    # `new_edge_score` defaults to None because the star-merging routines
    # below construct this tuple without per-cluster scores.
    unpool_description = namedtuple(
        "UnpoolDescription",
        ["edge_index", "cluster", "batch", "new_edge_score"],
        defaults=(None, ))

    def __init__(self, in_channels, node_score_method=None, dropout=0):
        super(StarPooling, self).__init__()
        self.in_channels = in_channels
        if node_score_method is None:
            node_score_method = self.compute_node_score_tanh
        self.compute_node_score = node_score_method
        self.dropout = dropout
        self.score_func = GraphConv(in_channels, 1)
        self.reset_parameters()

    def reset_parameters(self):
        self.score_func.reset_parameters()

    # @staticmethod
    # def compute_edge_score_softmax(raw_edge_score, edge_index, num_nodes):
    #     return softmax(raw_edge_score, edge_index[1], num_nodes)

    @staticmethod
    def compute_node_score_tanh(raw_node_score):
        return torch.tanh(raw_node_score)

    @staticmethod
    def compute_node_score_sigmoid(raw_node_score):
        return torch.sigmoid(raw_node_score)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        r"""Forward computation which computes the raw node score,
        normalizes it, and merges the stars.

        Args:
            x (Tensor): The node features.
            edge_index (LongTensor): The edge indices.
            edge_attr (Tensor, optional): The edge features.
            batch (LongTensor): Batch vector
                :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which
                assigns each node to a specific example.

        Return types:
            * **x** *(Tensor)* - The pooled node features.
            * **edge_index** *(LongTensor)* - The coarsened edge indices.
            * **edge_attr** *(Tensor)* - The coarsened edge features.
            * **batch** *(LongTensor)* - The coarsened batch vector.
            * **perm** *(LongTensor)* - The indices of the star centers.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        # e = torch.cat([x[edge_index[0]], x[edge_index[1]]], dim=-1)
        # e = self.lin(e).view(-1)
        # e = F.dropout(e, p=self.dropout, training=self.training)
        # e = self.compute_edge_score(e)
        # e = e + self.add_to_edge_score

        # TODO: change linear to consider both node and edge features
        # n = self.lin(x).view(-1)
        n = self.score_func(x, edge_index).view(-1)
        n = F.dropout(n, p=self.dropout, training=self.training)
        n = self.compute_node_score(n)

        x, edge_index, edge_attr, batch, unpool_info, perm = \
            self.__merge_stars_with_attr_gpu2__(x, edge_index, batch,
                                                edge_attr, n)
        edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm,
                                           num_nodes=n.size(0))

        return x, edge_index, edge_attr, batch, perm

    def __merge_stars_with_attr_gpu2__(self, x, edge_index, batch, edge_attr,
                                       node_score):
        node_argsort = torch.argsort(node_score, descending=True)
        cluster = torch.empty_like(batch, device=torch.device('cpu'))
        nodes_remain = torch.ones_like(batch, device=torch.device('cpu'),
                                       dtype=torch.bool)

        # Iterate through all nodes by descending score, turning each node
        # that has not yet been absorbed into the center of a new star.
        # Note: the cumulative-degree slicing below assumes `edge_index` is
        # sorted by source node.
        edge_index_cpu = edge_index.cpu()
        i = 0
        # pass num_nodes explicitly so trailing isolated nodes get an entry
        degrees = degree(edge_index_cpu[0], num_nodes=x.size(0)).long()
        cum_num_nodes = torch.cat(
            [degrees.new_zeros(1), degrees.cumsum(dim=0)[:-1]], dim=0).long()
        center_nodes = set()
        for node_idx in node_argsort.tolist():
            if not nodes_remain[node_idx]:
                continue
            dests = edge_index_cpu[1][cum_num_nodes[node_idx].item():
                                      cum_num_nodes[node_idx].item() +
                                      degrees[node_idx].item()]
            nodes_remain[dests] = False
            nodes_remain[node_idx] = False
            # add node_idx to center_nodes
            center_nodes.add(node_idx)
            cluster[node_idx] = i
            cluster[dests] = i
            i += 1
        cluster = cluster.to(x.device)

        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        N = new_x.size(0)
        new_edge_index, new_edge_attr = coalesce(cluster[edge_index],
                                                 edge_attr, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster, batch=batch)
        perm = sorted(center_nodes)
        perm = torch.from_numpy(np.array(perm)).view(-1).to(x.device)
        return (new_x, new_edge_index, new_edge_attr, new_batch, unpool_info,
                perm)

    def __merge_stars_with_attr_gpu__(self, x, edge_index, edge_attr, batch,
                                      node_score):
        device = x.device
        nodes_remaining = set(range(x.size(0)))
        node_argsort = torch.argsort(node_score, descending=True)
        cluster = torch.empty_like(batch, device=torch.device('cpu'))

        # Iterate through all nodes by descending score, selecting each as a
        # star center if it has not been absorbed by a previous star.
        edge_index_cpu = edge_index.cpu()
        center_nodes = set()
        i = 0
        for node_idx in node_argsort.tolist():
            if node_idx not in nodes_remaining:
                continue
            dest_bool = edge_index_cpu[0] == node_idx
            # get the connected nodes
            dests = set(edge_index_cpu[1][dest_bool].numpy())
            # remove the previously combined nodes
            dests.difference_update(center_nodes)
            nodes_remaining.difference_update(dests)
            nodes_remaining.remove(node_idx)
            # add node_idx to center_nodes
            center_nodes.add(node_idx)
            cluster[node_idx] = i
            cluster[list(dests)] = i
            i += 1

        # The remaining nodes are simply kept.
        for node_idx in nodes_remaining:
            cluster[node_idx] = i
            i += 1
        cluster = cluster.to(x.device)

        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        N = new_x.size(0)
        new_edge_index, new_edge_attr = coalesce(cluster[edge_index],
                                                 edge_attr, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster, batch=batch)
        perm = sorted(center_nodes)
        perm = torch.from_numpy(np.array(perm)).view(-1).to(device)
        return (new_x, new_edge_index, new_edge_attr, new_batch, unpool_info,
                perm)

    def unpool(self, x, unpool_info):
        r"""Unpools a previous star pooling step.

        For unpooling, :obj:`x` should be of same shape as those produced by
        this layer's :func:`forward` function. Then, it will produce an
        unpooled :obj:`x` in addition to :obj:`edge_index` and :obj:`batch`.

        Args:
            x (Tensor): The node features.
            unpool_info (unpool_description): Information that has
                been produced by :func:`StarPooling.forward`.

        Return types:
            * **x** *(Tensor)* - The unpooled node features.
            * **edge_index** *(LongTensor)* - The new edge indices.
            * **batch** *(LongTensor)* - The new batch vector.
        """
        new_x = x / unpool_info.new_edge_score.view(-1, 1)
        new_x = new_x[unpool_info.cluster]
        return new_x, unpool_info.edge_index, unpool_info.batch

    def __merge_star_nodes__(self, x, edge_index, batch, node_score):
        """Adapted from Edge Contraction Pooling's edge merging.

        :param x: node features
        :param edge_index: edge index
        :param batch: batch index
        :param node_score: node score tensor
        :return: pooled features, edge index, batch and unpool info
        """
        nodes_remaining = set(range(x.size(0)))

        cluster = torch.empty_like(batch, device=torch.device('cpu'))
        node_argsort = torch.argsort(node_score, descending=True)

        edge_index_cpu = edge_index.cpu()
        deg = tg.utils.degree(edge_index_cpu[0], x.size(0))

        # Iterate through all nodes, selecting a node as a star center if it
        # has not been absorbed by an already chosen star.
        i = 0
        new_score_indices = []
        for node_idx in node_argsort.tolist():
            # check if the node is still in nodes_remaining
            if node_idx not in nodes_remaining:
                continue
            dest_bool = edge_index_cpu[0] == node_idx
            dests = set(edge_index_cpu[1][dest_bool].numpy())
            # keep only neighbors that have not been merged yet (the
            # original used difference_update here, which inverted the test)
            dests.intersection_update(nodes_remaining - {node_idx})
            if len(dests) == 0:
                continue
            new_score_indices.append(node_idx)
            cluster[node_idx] = i  # the center joins its own cluster
            cluster[list(dests)] = i
            nodes_remaining.remove(node_idx)
            nodes_remaining.difference_update(dests)
            i += 1

        # The remaining nodes are simply kept.
        for node_idx in nodes_remaining:
            cluster[node_idx] = i
            i += 1
        cluster = cluster.to(x.device)

        # We compute the new features as an addition of the old ones, gated
        # by the score of each star's center node (ones for kept nodes); the
        # original indexed an undefined `edge_score` here.
        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        new_node_score = node_score[new_score_indices]
        if len(nodes_remaining) > 0:
            remaining_score = x.new_ones(
                (new_x.size(0) - len(new_score_indices), ))
            new_node_score = torch.cat([new_node_score, remaining_score])
        new_x = new_x * new_node_score.view(-1, 1)

        N = new_x.size(0)
        new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster, batch=batch,
                                              new_edge_score=new_node_score)

        return new_x, new_edge_index, new_batch, unpool_info

    def __merge_edges__(self, x, edge_index, batch, edge_score):
        """Copied from Edge Contraction Pooling.

        :param x: node features
        :param edge_index: edge index
        :param batch: batch index
        :param edge_score: edge score tensor
        :return: pooled features, edge index, batch and unpool info
        """
        nodes_remaining = set(range(x.size(0)))

        cluster = torch.empty_like(batch, device=torch.device('cpu'))
        edge_argsort = torch.argsort(edge_score, descending=True)

        # Iterate through all edges, selecting it if it is not incident to
        # another already chosen edge.
        i = 0
        new_edge_indices = []
        edge_index_cpu = edge_index.cpu()
        for edge_idx in edge_argsort.tolist():
            source = edge_index_cpu[0, edge_idx].item()
            if source not in nodes_remaining:
                continue

            target = edge_index_cpu[1, edge_idx].item()
            if target not in nodes_remaining:
                continue

            new_edge_indices.append(edge_idx)

            cluster[source] = i
            nodes_remaining.remove(source)

            if source != target:
                cluster[target] = i
                nodes_remaining.remove(target)

            i += 1

        # The remaining nodes are simply kept.
        for node_idx in nodes_remaining:
            cluster[node_idx] = i
            i += 1
        cluster = cluster.to(x.device)

        # We compute the new features as an addition of the old ones.
        new_x = scatter_add(x, cluster, dim=0, dim_size=i)
        new_edge_score = edge_score[new_edge_indices]
        if len(nodes_remaining) > 0:
            remaining_score = x.new_ones(
                (new_x.size(0) - len(new_edge_indices), ))
            new_edge_score = torch.cat([new_edge_score, remaining_score])
        new_x = new_x * new_edge_score.view(-1, 1)

        N = new_x.size(0)
        new_edge_index, _ = coalesce(cluster[edge_index], None, N, N)

        new_batch = x.new_empty(new_x.size(0), dtype=torch.long)
        new_batch = new_batch.scatter_(0, cluster, batch)

        unpool_info = self.unpool_description(edge_index=edge_index,
                                              cluster=cluster, batch=batch,
                                              new_edge_score=new_edge_score)

        return new_x, new_edge_index, new_batch, unpool_info

    def __repr__(self):
        return '{}({})'.format(self.__class__.__name__, self.in_channels)
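
# Usage sketch (illustrative only): StarPooling applied directly to a small
# path graph. The edge index is sorted by source node, which the default
# star-merging routine relies on for its cumulative-degree slicing.
def _star_pooling_demo():
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    x = torch.randn(4, 16)
    pool = StarPooling(16)
    x, edge_index, edge_attr, batch, perm = pool(x, edge_index)
    return x, perm                                     # pooled stars, centers
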
class SAGPooling(torch.nn.Module):
    r"""The self-attention pooling operator from the `"Self-Attention Graph
    Pooling" <https://arxiv.org/abs/1904.08082>`_ and `"Understanding
    Attention and Generalization in Graph Neural Networks"
    <https://arxiv.org/abs/1905.02850>`_ papers

    if :obj:`min_score` :math:`\tilde{\alpha}` is :obj:`None`:

        .. math::
            \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})

            \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

            \mathbf{X}^{\prime} &= (\mathbf{X} \odot
            \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

            \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}}

    if :obj:`min_score` :math:`\tilde{\alpha}` is a value in :math:`[0, 1]`:

        .. math::
            \mathbf{y} &= \mathrm{softmax}(\textrm{GNN}(\mathbf{X},\mathbf{A}))

            \mathbf{i} &= \mathbf{y}_i > \tilde{\alpha}

            \mathbf{X}^{\prime} &= (\mathbf{X} \odot \mathbf{y})_{\mathbf{i}}

            \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on a learnable projection score
    :math:`\mathbf{p}`. Projection scores are learned by a graph neural
    network layer.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`. This value is
            ignored if :obj:`min_score` is not :obj:`None`.
            (default: :obj:`0.5`)
        gnn (string, optional): Specifies which graph neural network layer
            to use for calculating projection scores (one of
            :obj:`"GraphConv"`, :obj:`"GCN"`, :obj:`"GAT"` or
            :obj:`"SAGE"`). (default: :obj:`"GraphConv"`)
        min_score (float, optional): Minimal node score
            :math:`\tilde{\alpha}` which is used to compute indices of
            pooled nodes :math:`\mathbf{i} = \mathbf{y}_i > \tilde{\alpha}`.
            When this value is not :obj:`None`, the :obj:`ratio` argument is
            ignored. (default: :obj:`None`)
        multiplier (float, optional): Coefficient by which features get
            multiplied after pooling. This can be useful for large graphs
            and when :obj:`min_score` is used. (default: :obj:`1`)
        **kwargs (optional): Additional parameters for initializing the
            graph neural network layer.
""" def __init__(self, in_channels, ratio=0.5, gnn='GraphConv', min_score=None, multiplier=1, **kwargs): super(SAGPooling, self).__init__() self.in_channels = in_channels self.ratio = ratio self.min_score = min_score self.multiplier = multiplier self.gnn_name = gnn assert gnn in ['GraphConv', 'GCN', 'GAT', 'SAGE'] if gnn == 'GCN': self.gnn = GCNConv(self.in_channels, 1, **kwargs) elif gnn == 'GAT': self.gnn = GATConv(self.in_channels, 1, **kwargs) elif gnn == 'SAGE': self.gnn = SAGEConv(self.in_channels, 1, **kwargs) else: self.gnn = GraphConv(self.in_channels, 1, **kwargs) self.reset_parameters() def reset_parameters(self): self.gnn.reset_parameters() def forward(self, x, edge_index, edge_attr=None, batch=None, attn=None): """""" if batch is None: batch = edge_index.new_zeros(x.size(0)) attn = x if attn is None else attn attn = attn.unsqueeze(-1) if attn.dim() == 1 else attn score = self.gnn(attn, edge_index).view(-1) if self.min_score is None: score = torch.tanh(score) else: score = softmax(score, batch) perm = topk(score, self.ratio, batch, self.min_score) x = x[perm] * score[perm].view(-1, 1) x = self.multiplier * x if self.multiplier != 1 else x batch = batch[perm] edge_index, edge_attr = filter_adj(edge_index, edge_attr, perm, num_nodes=score.size(0)) return x, edge_index, edge_attr, batch, perm, score[perm] def __repr__(self): return '{}({}, {}, {}={}, multiplier={})'.format( self.__class__.__name__, self.gnn_name, self.in_channels, 'ratio' if self.min_score is None else 'min_score', self.ratio if self.min_score is None else self.min_score, self.multiplier)
class GraphNN(torch.nn.Module):
    def __init__(self, num_layers, num_input_features, hidden):
        super(GraphNN, self).__init__()
        self.conv1 = GraphConv(num_input_features, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden))
        self.lin1 = torch.nn.Linear(3 * hidden, hidden)
        self.lin2 = torch.nn.Linear(hidden, 2)

    def reset_parameters(self):
        # reset all conv and linear layers
        self.conv1.reset_parameters()
        for conv in self.convs:
            # .reset_parameters() is a method of torch_geometric.nn.GraphConv
            conv.reset_parameters()
        # .reset_parameters() is a method of torch.nn.Linear
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        # data: Batch(batch=[num_nodes_in_batch],
        #             edge_attr=[2*num_nodes_in_batch, num_edge_features_per_edge],
        #             edge_index=[2, 2*num_nodes_in_batch],
        #             pos=[num_nodes_in_batch, 2],
        #             x=[num_nodes_in_batch, num_input_features_per_node],
        #             y=[num_graphs_in_batch, num_classes])
        # example: Batch(batch=[2490], edge_attr=[4980, 1],
        #                edge_index=[2, 4980], pos=[2490, 2], x=[2490, 33],
        #                y=[32, 2])
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # x.shape: torch.Size([num_nodes_in_batch, num_input_features_per_node])
        # edge_index.shape: torch.Size([2, 2*num_nodes_in_batch])
        # batch.shape: torch.Size([num_nodes_in_batch])
        # example: x.shape = torch.Size([2490, 33])
        #          edge_index.shape = torch.Size([2, 4980])
        #          batch.shape = torch.Size([2490])

        # graph convolutions and relu activation
        x = F.relu(self.conv1(x, edge_index))
        # x.shape: torch.Size([num_nodes_in_batch, hidden])
        # example: x.shape = torch.Size([2490, 66])
        for conv in self.convs:
            x = F.relu(conv(x, edge_index))
            # x.shape: torch.Size([num_nodes_in_batch, hidden])
            # example: x.shape = torch.Size([2490, 66])

        x = torch.cat([
            global_add_pool(x, batch),
            global_mean_pool(x, batch),
            global_max_pool(x, batch)
        ], dim=1)
        # x.shape: torch.Size([num_graphs_in_batch, 3*hidden])
        # example: x.shape = torch.Size([32, 3*66])

        # linear layers, activation function, dropout
        x = F.relu(self.lin1(x))
        # x.shape: torch.Size([num_graphs_in_batch, hidden])
        # example: x.shape = torch.Size([32, 66])
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        # x.shape: torch.Size([num_graphs_in_batch, num_classes])
        # example: x.shape = torch.Size([32, 2])
        output = F.log_softmax(x, dim=-1)
        return output

    def __repr__(self):
        # for getting a printable representation of the object
        return self.__class__.__name__
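
# Usage sketch (illustrative only): a single training step on an invented
# mini-batch of four random graphs with 33 input features per node.
def _graphnn_demo():
    graphs = [Data(x=torch.randn(6, 33),
                   edge_index=torch.randint(0, 6, (2, 12)))
              for _ in range(4)]
    batch = Batch.from_data_list(graphs)
    model = GraphNN(num_layers=2, num_input_features=33, hidden=66)
    out = model(batch)                                 # [4, 2] log-probs
    return F.nll_loss(out, torch.tensor([0, 1, 1, 0]))
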
class NodeImportance(torch.nn.Module):
    def __init__(self, in_channels, ratio=0.5, layer=1, gnn='GCN', bias=True,
                 **kwargs):
        super(NodeImportance, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        self.layer = layer
        self.gnn_name = gnn  # needed by __repr__ (was missing originally)

        assert gnn in ['GCN', 'GAT', 'SAGE']
        if gnn == 'GCN':
            # note: the 'GCN' option is backed by GraphConv layers here
            if layer == 1:
                self.gnn = GraphConv(self.in_channels, 1, **kwargs)
            elif layer == 2:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn2 = GraphConv(self.in_channels, 1, **kwargs)
            elif layer == 3:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn2 = GraphConv(self.in_channels, self.in_channels,
                                      **kwargs)
                self.gnn3 = GraphConv(self.in_channels, 1, **kwargs)
        elif gnn == 'GAT':
            self.gnn = GATConv(self.in_channels, 1, **kwargs)
        else:
            self.gnn = SAGEConv(self.in_channels, 1, **kwargs)

        self.weight_closeness = Parameter(torch.Tensor(1))
        self.weight_degree = Parameter(torch.Tensor(1))
        self.weight_score = Parameter(torch.Tensor(1))
        if bias:
            self.bias = Parameter(torch.Tensor(1))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        if self.layer == 1:
            self.gnn.reset_parameters()
        elif self.layer == 2:
            self.gnn1.reset_parameters()
            self.gnn2.reset_parameters()
        elif self.layer == 3:
            self.gnn1.reset_parameters()
            self.gnn2.reset_parameters()
            self.gnn3.reset_parameters()
        uniform_(self.weight_closeness, a=0, b=1)
        uniform_(self.weight_degree, a=0, b=1)
        uniform_(self.weight_score, a=0, b=1)
        if self.bias is not None:  # guard: bias may be disabled
            uniform_(self.bias, a=0, b=1)

    def forward(self, x, edge_index, closeness, degree, edge_attr=None,
                batch=None):
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        x = x.unsqueeze(-1) if x.dim() == 1 else x
        if self.layer == 1:
            score = torch.relu(self.gnn(x, edge_index).view(-1))
        elif self.layer == 2:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index).view(-1))
        elif self.layer == 3:
            score = torch.relu(self.gnn1(x, edge_index))
            score = torch.relu(self.gnn2(score, edge_index))
            score = torch.relu(self.gnn3(score, edge_index).view(-1))

        # centrality adjustment
        closeness = closeness * self.weight_closeness
        degree = degree * self.weight_degree
        centrality = closeness + degree
        if self.bias is not None:
            centrality += self.bias
        score = score * self.weight_score
        score = score + centrality
        score = F.relu(score)

        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]

        return x, perm, batch

    def __repr__(self):
        return '{}({}, {}, ratio={})'.format(self.__class__.__name__,
                                             self.gnn_name, self.in_channels,
                                             self.ratio)
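
# Usage sketch (illustrative only): degree centrality comes from
# torch_geometric.utils.degree; the closeness values below are random
# placeholders standing in for a real closeness-centrality computation
# (e.g. via networkx).
def _node_importance_demo():
    edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                               [1, 0, 2, 1, 3, 2]])
    x = torch.randn(4, 8)
    deg = degree(edge_index[0], num_nodes=4)
    closeness = torch.rand(4)                          # placeholder values
    layer = NodeImportance(8, ratio=0.5, layer=2, gnn='GCN')
    x, perm, batch = layer(x, edge_index, closeness, deg)
    return x, perm
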
class Graclus(torch.nn.Module):
    def __init__(self, num_features, num_classes, num_layers, hidden,
                 pooling_type, no_cat=False, encode_edge=False):
        super(Graclus, self).__init__()
        self.encode_edge = encode_edge
        if encode_edge:
            # assumes an edge-aware GCNConv variant (as in the OGB examples)
            # whose constructor takes the embedding dimension only
            self.conv1 = GCNConv(hidden, aggr='add')
        else:
            self.conv1 = GraphConv(num_features, hidden, aggr='add')
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GraphConv(hidden, hidden, aggr='add'))
        self.jump = JumpingKnowledge(mode='cat')
        self.lin1 = Linear(num_layers * hidden, hidden)
        if no_cat:
            self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, num_classes)
        self.pooling_type = pooling_type
        self.no_cat = no_cat
        self.atom_encoder = AtomEncoder(emb_dim=hidden)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.jump.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        if self.encode_edge:
            x = self.atom_encoder(x)
            x = self.conv1(x, edge_index, data.edge_attr)
        else:
            x = self.conv1(x, edge_index)
        x = F.relu(x)
        xs = [global_mean_pool(x, batch)]
        for i, conv in enumerate(self.convs):
            x = F.relu(conv(x, edge_index))
            xs += [global_mean_pool(x, batch)]
            if self.pooling_type != 'none':
                if self.pooling_type == 'complement':
                    # `batched_negative_edges` is a project-specific helper
                    # that samples the complement graph per example
                    complement = batched_negative_edges(
                        edge_index=edge_index, batch=batch,
                        force_undirected=True)
                    cluster = graclus(complement, num_nodes=x.size(0))
                elif self.pooling_type == 'graclus':
                    cluster = graclus(edge_index, num_nodes=x.size(0))
                data = Batch(x=x, edge_index=edge_index, batch=batch)
                data = max_pool(cluster, data)
                x, edge_index, batch = data.x, data.edge_index, data.batch
        if not self.no_cat:
            x = self.jump(xs)
        else:
            x = global_mean_pool(x, batch)
        x = F.relu(self.lin1(x))
        x = self.lin2(x)
        return x
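
# Usage sketch (illustrative only): the 'graclus' pooling path needs the
# torch-cluster package, and constructing the model requires ogb for
# AtomEncoder even when encode_edge is False.
def _graclus_demo():
    graphs = [Data(x=torch.randn(8, 5),
                   edge_index=torch.randint(0, 8, (2, 20)))
              for _ in range(2)]
    batch = Batch.from_data_list(graphs)
    model = Graclus(num_features=5, num_classes=2, num_layers=3, hidden=16,
                    pooling_type='graclus')
    return model(batch)                                # [2, 2] raw logits
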