class SortPool(torch.nn.Module):
    """SAGEConv stack with a sort-pooling readout and a two-layer MLP head.

    Args:
        dataset: Object exposing ``num_features`` (input feature width).
        num_layers: Total number of SAGEConv layers (first layer included).
        hidden: Width of every hidden representation.
        num_classes: Number of output classes.
    """

    def __init__(self, dataset, num_layers, hidden, num_classes):
        super().__init__()
        # ``k`` is handed to ``global_sort_pool`` and fixes the width of
        # ``lin1``'s input (k * hidden).
        self.k = 10
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        self.lin1 = nn.Linear(self.k * hidden, hidden)
        self.lin2 = nn.Linear(hidden, num_classes)

    def reset_parameters(self):
        """Re-initialise all convolutional and linear layers."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for head in (self.lin1, self.lin2):
            head.reset_parameters()

    def forward(self, data):
        """Return per-graph log-probabilities for the batch in ``data``.

        ``data`` must expose ``x``, ``edge_index`` and ``batch``.
        """
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        hidden_state = F.relu(self.conv1(feats, edges))
        for layer in self.convs:
            hidden_state = F.relu(layer(hidden_state, edges))
        pooled = global_sort_pool(hidden_state, graph_ids, self.k)
        pooled = F.relu(self.lin1(pooled))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        logits = self.lin2(pooled)
        return F.log_softmax(logits, dim=-1)

    def __repr__(self):
        return type(self).__name__
class GlobalAttentionNet(torch.nn.Module):
    """SAGEConv stack with a ``GlobalAttention`` readout and an MLP head.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of SAGEConv layers (first layer included).
        hidden: Width of every hidden representation.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        stacked = [SAGEConv(hidden, hidden) for _ in range(num_layers - 1)]
        self.convs = torch.nn.ModuleList(stacked)
        # The attention gate is a single linear layer producing one value
        # per node.
        self.att = GlobalAttention(Linear(hidden, 1))
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise convolutions, the attention readout and the head."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for module in (self.att, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, data):
        """Return log-probabilities; ``data`` needs ``x``, ``edge_index``, ``batch``."""
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(feats, edges))
        for layer in self.convs:
            h = F.relu(layer(h, edges))
        readout = self.att(h, graph_ids)
        readout = F.relu(self.lin1(readout))
        readout = F.dropout(readout, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(readout), dim=-1)

    def __repr__(self):
        return type(self).__name__
class SortPool(torch.nn.Module):
    """SAGEConv stack, sort pooling, a 1-D convolution and an MLP head.

    Unlike the other heads in this file, ``forward`` takes the graph tensors
    directly and returns raw scores (no ``log_softmax``).

    Args:
        in_channels: Input feature width.
        hidden_channels: Width of the SAGEConv outputs and of ``lin1``'s output.
        output_dim: Width of the final linear layer.
        num_layers: Total number of SAGEConv layers (first layer included).
    """

    def __init__(self, in_channels, hidden_channels, output_dim, num_layers):
        super().__init__()
        self.k = 30
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(hidden_channels, hidden_channels)
             for _ in range(num_layers - 1)]
        )
        # Kernel size 5 without padding leaves k - 5 + 1 positions, each
        # with 32 channels; that product is the input width of ``lin1``.
        self.conv1d = Conv1d(hidden_channels, 32, 5)
        self.lin1 = Linear(32 * (self.k - 5 + 1), hidden_channels)
        self.lin2 = Linear(hidden_channels, output_dim)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for module in (self.conv1d, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, x, edge_index, batch):
        """Return un-normalised scores for each graph in the batch."""
        h = F.relu(self.conv1(x, edge_index))
        for layer in self.convs:
            h = F.relu(layer(h, edge_index))
        pooled = global_sort_pool(h, batch, self.k)
        # Split the flat pooled vector into k slots and move the feature
        # axis in front of the slot axis, as Conv1d convolves the last axis.
        sequence = pooled.view(len(pooled), self.k, -1).permute(0, 2, 1)
        sequence = F.relu(self.conv1d(sequence))
        flat = sequence.view(len(sequence), -1)
        flat = F.relu(self.lin1(flat))
        flat = F.dropout(flat, p=0.5, training=self.training)
        return self.lin2(flat)
class GraphSAGEWithJK(torch.nn.Module):
    """SAGEConv stack whose per-layer outputs are merged by JumpingKnowledge.

    The outputs of all ``num_layers`` convolutions are combined with
    ``JumpingKnowledge(mode='cat')``, mean-pooled per graph and classified by
    a two-layer MLP.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of SAGEConv layers (first layer included).
        hidden: Width of every hidden representation.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        self.jump = JumpingKnowledge(mode='cat')
        # 'cat' mode joins one ``hidden``-wide tensor per layer.
        self.lin1 = Linear(num_layers * hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise convolutions, the jumping-knowledge module and the head."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for module in (self.jump, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, data):
        """Return log-probabilities; ``data`` needs ``x``, ``edge_index``, ``batch``."""
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(feats, edges))
        layer_outputs = [h]
        for layer in self.convs:
            h = F.relu(layer(h, edges))
            layer_outputs.append(h)
        merged = self.jump(layer_outputs)
        pooled = global_mean_pool(merged, graph_ids)
        pooled = F.relu(self.lin1(pooled))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(pooled), dim=-1)

    def __repr__(self):
        return type(self).__name__
class GraphSAGE(torch.nn.Module):
    """Node-level SAGEConv classifier that forwards edge weights to each layer.

    Args:
        num_layers: Total number of SAGEConv layers (first layer included).
        hidden: Width of every hidden representation.
        features_num: Input feature width.
        num_class: Number of output classes.
    """

    def __init__(self, num_layers=2, hidden=16, features_num=16, num_class=2):
        super().__init__()
        self.sage1 = SAGEConv(features_num, hidden)
        self.convs = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(SAGEConv(hidden, hidden))
        self.lin2 = Linear(hidden, num_class)

    def reset_parameters(self):
        """Re-initialise every layer this model actually owns.

        The previous version also reset ``self.first_lin``, which is never
        created in ``__init__``, so every call raised ``AttributeError``.
        """
        self.sage1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return per-node log-probabilities.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_weight``.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.relu(self.sage1(x, edge_index, edge_weight=edge_weight))
        x = F.dropout(x, p=0.5, training=self.training)
        for conv in self.convs:
            x = F.relu(conv(x, edge_index, edge_weight=edge_weight))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def __repr__(self):
        return self.__class__.__name__
class SAGE(torch.nn.Module):
    """Two-layer SAGEConv model with a concatenation readout (fine-tuned).

    The input is projected by ``first_lin``, passed through two SAGEConv
    layers, and the projection plus both convolution outputs are
    concatenated before the final linear classifier.

    Args:
        num_layers: Only sizes ``fuse_weight``; the model always has exactly
            two convolutions.
        hidden: Width of every hidden representation.
        features_num: Input feature width.
        num_class: Number of output classes.
    """

    def __init__(self, num_layers=2, hidden=32, features_num=32, num_class=2):
        super(SAGE, self).__init__()
        self.conv1 = SAGEConv(hidden, hidden)
        self.conv2 = SAGEConv(hidden, hidden)
        # Three ``hidden``-wide tensors are concatenated in ``forward``.
        self.out = Linear(hidden * 3, num_class)
        self.first_lin = Linear(features_num, hidden)
        # NOTE(review): ``fuse_weight`` is created but never read in
        # ``forward``; it is kept so existing checkpoints still load.
        self.fuse_weight = torch.nn.Parameter(
            torch.full((num_layers,), float(1) / (num_layers + 1),
                       dtype=torch.float),
            requires_grad=True)

    def reset_parameters(self):
        """Re-initialise every layer and restore ``fuse_weight``.

        The previous version reset ``self.convs`` and ``self.lin2``, neither
        of which exists on this class, and skipped ``conv2``, ``out`` and
        ``fuse_weight``.
        """
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.out.reset_parameters()
        # Same value as in ``__init__``: 1 / (num_layers + 1), where
        # num_layers is recovered from the parameter's length.
        self.fuse_weight.data.fill_(float(1) / (self.fuse_weight.numel() + 1))

    def forward(self, data):
        """Return per-node log-probabilities.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_weight``.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.relu(self.first_lin(x))
        x = F.dropout(x, p=0.5, training=self.training)
        xx = x
        x = self.conv1(x, edge_index, edge_weight)
        x = F.dropout(x, p=0.2, training=self.training)
        xx = torch.cat([xx, x], dim=1)
        x = self.conv2(x, edge_index, edge_weight)
        x = F.dropout(x, p=0.2, training=self.training)
        xx = torch.cat([xx, x], dim=1)
        x = self.out(xx)
        return F.log_softmax(x, dim=-1)
class Net(torch.nn.Module):
    """Five-layer SAGEConv network with a Set2Set readout.

    The hidden width comes from the module-level ``args.hidden`` and the
    input/output sizes from the module-level ``dataset``.
    """

    def __init__(self):
        super().__init__()
        width = args.hidden
        depth = 5
        self.conv1 = SAGEConv(dataset.num_features, width)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(width, width) for _ in range(depth - 1)]
        )
        self.set2set = Set2Set(width, processing_steps=4)
        # ``lin1`` is sized for a readout twice as wide as ``width``.
        self.lin1 = Linear(2 * width, width)
        self.lin2 = Linear(width, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise convolutions, the Set2Set readout and the head."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for module in (self.set2set, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, data):
        """Return log-probabilities; ``data`` needs ``x``, ``edge_index``, ``batch``."""
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(feats, edges))
        for layer in self.convs:
            h = F.relu(layer(h, edges))
        readout = self.set2set(h, graph_ids)
        readout = F.relu(self.lin1(readout))
        readout = F.dropout(readout, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(readout), dim=-1)
class GraphSAGE(torch.nn.Module):
    """SAGEConv stack with a mean-pool readout of caller-specified size.

    Args:
        num_features: Input feature width.
        output_channels: Number of output classes.
        num_layers: Total number of SAGEConv layers (first layer included).
        hidden: Width of every hidden representation.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_features, output_channels, num_layers=3,
                 hidden=128, **kwargs):
        super().__init__()
        self.conv1 = SAGEConv(num_features, hidden)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, output_channels)

    def reset_parameters(self):
        """Re-initialise all convolutional and linear layers."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for head in (self.lin1, self.lin2):
            head.reset_parameters()

    def forward(self, data, target_size, **kwargs):
        """Return log-probabilities for the batch in ``data``.

        ``target_size`` is forwarded to ``global_mean_pool`` as ``size``;
        extra keyword arguments are ignored.
        """
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(feats, edges))
        for layer in self.convs:
            h = F.relu(layer(h, edges))
        pooled = global_mean_pool(h, graph_ids, size=target_size)
        pooled = F.relu(self.lin1(pooled))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(pooled), dim=-1)

    def __repr__(self):
        return type(self).__name__
class Net(torch.nn.Module):
    """Five-layer SAGEConv network with sort pooling and a Conv1d head.

    The hidden width comes from the module-level ``args.hidden`` and the
    input/output sizes from the module-level ``dataset``.
    """

    def __init__(self):
        super().__init__()
        width = args.hidden
        depth = 5
        self.k = 30
        self.conv1 = SAGEConv(dataset.num_features, width)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(width, width) for _ in range(depth - 1)]
        )
        # Kernel size 5 without padding leaves k - 5 + 1 positions of
        # 32 channels each, which is the input width of ``lin1``.
        self.conv1d = Conv1d(width, 32, 5)
        self.lin1 = Linear(32 * (self.k - 5 + 1), width)
        self.lin2 = Linear(width, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every learnable sub-module."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for module in (self.conv1d, self.lin1, self.lin2):
            module.reset_parameters()

    def forward(self, data):
        """Return log-probabilities; ``data`` needs ``x``, ``edge_index``, ``batch``."""
        feats, edges, graph_ids = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(feats, edges))
        for layer in self.convs:
            h = F.relu(layer(h, edges))
        pooled = global_sort_pool(h, graph_ids, self.k)
        # Split into k slots and put the feature axis before the slot axis,
        # since Conv1d convolves along the last axis.
        sequence = pooled.view(len(pooled), self.k, -1).permute(0, 2, 1)
        sequence = F.relu(self.conv1d(sequence))
        flat = sequence.view(len(sequence), -1)
        flat = F.relu(self.lin1(flat))
        flat = F.dropout(flat, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(flat), dim=-1)
def reset_parameters(self):
    """Re-initialise the down convolutions, pooling layers and up convolutions."""
    # Collections are looked up one at a time, in the original order:
    # down convs first, then pools, then up convs.
    for collection_name in ('down_convs', 'pools', 'up_convs'):
        for module in getattr(self, collection_name):
            module.reset_parameters()
class SAGEEncoder(torch.nn.Module):
    """Single SAGEConv layer followed by a channel-wise PReLU.

    Args:
        in_dim: Input feature width.
        out_dim: Output feature width (also the number of PReLU slopes).
    """

    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.conv = SAGEConv(in_dim, out_dim)
        self.sigma = nn.PReLU(out_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution; the PReLU slopes are left untouched."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return the activated convolution output for the given graph."""
        convolved = self.conv(x, edge_index)
        return self.sigma(convolved)
class GNNBlock(torch.nn.Module):
    """Pre-activation block: LayerNorm, ReLU, optional mask, then SAGEConv.

    Args:
        in_channels: Input feature width (also the LayerNorm size).
        out_channels: Output feature width of the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm = LayerNorm(in_channels, elementwise_affine=True)
        self.conv = SAGEConv(in_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the normalisation and the convolution."""
        for module in (self.norm, self.conv):
            module.reset_parameters()

    def forward(self, x, edge_index, dropout_mask=None):
        """Apply the block.

        ``dropout_mask`` is multiplied into the activations only while the
        module is in training mode; it is ignored during evaluation.
        """
        activated = self.norm(x).relu()
        apply_mask = self.training and dropout_mask is not None
        if apply_mask:
            activated = activated * dropout_mask
        return self.conv(activated, edge_index)
class BenSAGE(torch.nn.Module):
    """Linear input projection, a SAGEConv stack and a linear classifier.

    Args:
        num_layers: ``num_layers - 1`` hidden-to-hidden SAGEConv layers are
            applied in ``forward``.
        hidden: Width of every hidden representation.
        features_num: Input feature width.
        num_class: Number of output classes.
    """

    def __init__(self, num_layers=2, hidden=16, features_num=16, num_class=2):
        super().__init__()
        # NOTE(review): ``conv1`` is constructed and reset, but ``forward``
        # never calls it; the input goes through ``first_lin`` instead.
        self.conv1 = SAGEConv(features_num, hidden)
        # Hidden-to-hidden layers (layers 2 .. num_layers).
        self.convs = torch.nn.ModuleList(
            [SAGEConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        # Output classifier and input projection.
        self.lin2 = Linear(hidden, num_class)
        self.first_lin = Linear(features_num, hidden)

    def reset_parameters(self):
        """Re-initialise the weights of every layer."""
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return per-node log-probabilities.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_weight``.
        """
        feats, edges, weights = data.x, data.edge_index, data.edge_weight
        # Input projection, ReLU and dropout.
        h = F.dropout(F.relu(self.first_lin(feats)), p=0.5,
                      training=self.training)
        # Edge-weighted graph convolutions.
        for layer in self.convs:
            h = F.relu(layer(h, edges, edge_weight=weights))
        # Dropout before the classifier.
        h = F.dropout(h, p=0.5, training=self.training)
        return F.log_softmax(self.lin2(h), dim=-1)

    def __repr__(self):
        return type(self).__name__
class GraphSAGE(torch.nn.Module):
    """Convolution-only node classifier mixing SAGEConv and SAGELafConv.

    The first and last layers are ``SAGEConv``; the ``num_layers - 2`` layers
    in between are ``SAGELafConv``.  There is no linear head: the last
    convolution maps straight to ``dataset.num_classes``.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of convolution layers.
        hidden: Width of every hidden representation.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = SAGEConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList(
            [SAGELafConv(hidden, hidden) for _ in range(num_layers - 2)]
        )
        self.convn = SAGEConv(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise the first, middle and last convolutions in that order."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        self.convn.reset_parameters()

    def forward(self, data):
        """Return per-node log-probabilities; ``data`` needs ``x`` and ``edge_index``."""
        feats, edges = data.x, data.edge_index
        h = F.relu(self.conv1(feats, edges))
        h = F.dropout(h, p=0.5, training=self.training)
        for layer in self.convs:
            h = F.relu(layer(h, edges))
        # The output layer is also passed through ReLU before log_softmax.
        scores = F.relu(self.convn(h, edges))
        return F.log_softmax(scores, dim=-1)

    def __repr__(self):
        return type(self).__name__
class SAGPooling(torch.nn.Module):
    r"""The self-attention pooling operator from the `"Self-Attention Graph
    Pooling" <https://arxiv.org/abs/1904.08082>`_ paper

    .. math::
        \mathbf{y} &= \textrm{GNN}(\mathbf{X}, \mathbf{A})

        \mathbf{i} &= \mathrm{top}_k(\mathbf{y})

        \mathbf{X}^{\prime} &= (\mathbf{X} \odot
        \mathrm{tanh}(\mathbf{y}))_{\mathbf{i}}

        \mathbf{A}^{\prime} &= \mathbf{A}_{\mathbf{i},\mathbf{i}},

    where nodes are dropped based on a learnable projection score
    :math:`\mathbf{p}`.
    Projection scores are learned by a graph neural network layer.

    Args:
        in_channels (int): Size of each input sample.
        ratio (float): Graph pooling ratio, which is used to compute
            :math:`k = \lceil \mathrm{ratio} \cdot N \rceil`.
            (default: :obj:`0.5`)
        gnn (string, optional): Specifies which graph neural network layer to
            use for calculating projection scores (one of :obj:`"GCN"`,
            :obj:`"GAT"` or :obj:`"SAGE"`). (default: :obj:`GCN`)
        **kwargs (optional): Additional parameters for initializing the graph
            neural network layer.
    """

    def __init__(self, in_channels, ratio=0.5, gnn='GCN', **kwargs):
        super().__init__()

        self.in_channels = in_channels
        self.ratio = ratio
        self.gnn_name = gnn

        assert gnn in ['GCN', 'GAT', 'SAGE']
        # Every option scores nodes with one output channel; 'GCN' selects
        # GraphConv here.
        if gnn == 'GCN':
            scorer_cls = GraphConv
        elif gnn == 'GAT':
            scorer_cls = GATConv
        else:
            scorer_cls = SAGEConv
        self.gnn = scorer_cls(self.in_channels, 1, **kwargs)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the scoring layer."""
        self.gnn.reset_parameters()

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Pool the graph and return ``(x, edge_index, edge_attr, batch, perm)``.

        When ``batch`` is omitted every node is assigned to graph 0.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))

        if x.dim() == 1:
            x = x.unsqueeze(-1)

        raw_score = self.gnn(x, edge_index).view(-1)
        score = torch.tanh(raw_score)

        perm = topk(score, self.ratio, batch)
        # Keep the selected nodes and gate their features by their score.
        kept = x[perm] * score[perm].view(-1, 1)
        kept_batch = batch[perm]
        edge_index, edge_attr = filter_adj(
            edge_index, edge_attr, perm, num_nodes=score.size(0))

        return kept, edge_index, edge_attr, kept_batch, perm

    def __repr__(self):
        template = '{}({}, {}, ratio={})'
        return template.format(self.__class__.__name__, self.gnn_name,
                               self.in_channels, self.ratio)
class GraphSAGE(torch.nn.Module):
    """Configurable node-level SAGEConv classifier with a built-in training loop.

    ``args`` is merged with the defaults listed in ``__init__`` through
    ``set_default``.  The keys ``'timer'``, ``'agg'``, ``'features_num'`` and
    ``'num_class'`` are read directly and have no default in that dict, so
    the caller is expected to supply them.

    ``'agg'`` selects the readout: ``'concat'`` joins the input projection
    and every convolution output, ``'self'`` uses the last output only.
    """

    def __init__(self, args):
        super(GraphSAGE, self).__init__()
        self.args = set_default(args, {
            'num_layers': 2,
            'hidden': 64,
            'hidden2': 32,
            'dropout': 0.5,
            'lr': 0.005,
            'epoches': 300,
            'weight_decay': 5e-4,
            'act': 'leaky_relu',
            'withbn': True,
        })
        self.timer = self.args['timer']
        self.dropout = self.args['dropout']
        self.agg = self.args['agg']
        self.withbn = self.args['withbn']
        # Fail early with a clear message; previously an unknown value left
        # ``outdim`` unbound and surfaced as an UnboundLocalError below.
        if self.agg not in ('concat', 'self'):
            raise ValueError(
                "agg must be 'concat' or 'self', got {!r}".format(self.agg))

        hidden = self.args['hidden']
        hidden2 = self.args['hidden2']

        self.conv1 = SAGEConv(hidden, hidden)
        self.convs = torch.nn.ModuleList()
        if self.withbn:
            self.bn1 = BatchNorm1d(hidden)
        # Always present (empty without batch norm) so that construction and
        # ``reset_parameters`` work when ``withbn`` is False.
        self.bns = torch.nn.ModuleList()

        hd = [hidden, hidden]
        in_dim = hidden
        for _ in range(self.args['num_layers'] - 1):
            hd.append(hidden2)
            # Each layer consumes the previous layer's width; the old code
            # always declared ``hidden`` inputs, which breaks from the second
            # stacked layer on whenever hidden != hidden2.
            self.convs.append(SAGEConv(in_dim, hidden2))
            if self.withbn:
                self.bns.append(BatchNorm1d(hidden2))
            in_dim = hidden2

        if self.agg == 'concat':
            outdim = sum(hd)
        else:
            outdim = hd[-1]

        if self.args['act'] == 'leaky_relu':
            self.act = F.leaky_relu
        elif self.args['act'] == 'tanh':
            self.act = torch.tanh
        else:
            # Any other value means "no activation".
            self.act = lambda x: x

        self.lin2 = Linear(outdim, self.args['num_class'])
        self.first_lin = Linear(self.args['features_num'], hidden)

    def reset_parameters(self):
        """Re-initialise every learnable layer.

        This now includes the input projection and the batch-norm layers,
        which the previous version skipped.
        """
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        if self.withbn:
            self.bn1.reset_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for bn in self.bns:
            bn.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return per-node log-probabilities.

        ``data`` must expose ``x``, ``edge_index`` and ``edge_weight``.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = self.act(self.first_lin(x))
        xs = [x]
        x = self.act(self.conv1(x, edge_index, edge_weight=edge_weight))
        if self.withbn:
            x = self.bn1(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        xs.append(x)
        # Index into ``bns`` rather than zipping, so the convolutions still
        # run when batch norm is disabled.
        for idx, conv in enumerate(self.convs):
            x = self.act(conv(x, edge_index, edge_weight=edge_weight))
            if self.withbn:
                x = self.bns[idx](x)
            xs.append(x)
        if self.agg == 'concat':
            x = torch.cat(xs, dim=1)
        elif self.agg == 'self':
            x = xs[-1]
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1)

    def train_predict(self, data, train_mask=None, val_mask=None,
                      return_out=True):
        """Train on ``data`` with Adam, then predict in eval mode.

        Args:
            data: Graph object exposing ``y`` plus the inputs of ``forward``;
                ``train_mask`` / ``test_mask`` are read from it when needed.
            train_mask: Mask of training nodes; defaults to ``data.train_mask``.
            val_mask: Optional mask; when given, the outputs for those nodes
                are returned as an extra element.
            return_out: If True return the outputs for all nodes, otherwise
                only those selected by ``data.test_mask``.

        Returns:
            ``(pred, flag_end)`` or, with ``val_mask``,
            ``(pred, val_out, flag_end)``.  ``flag_end`` is True when training
            stopped early because of the time budget.
        """
        if train_mask is None:
            train_mask = data.train_mask
        optimizer = torch.optim.Adam(self.parameters(), lr=self.args['lr'],
                                     weight_decay=self.args['weight_decay'])
        flag_end = False
        st = time.time()
        # NOTE(review): the range starts at 1, so this runs epoches - 1
        # optimisation steps; kept as is to preserve training behaviour.
        for epoch in range(1, self.args['epoches']):
            self.train()
            optimizer.zero_grad()
            res = self.forward(data)
            loss = F.nll_loss(res[train_mask], data.y[train_mask])
            loss.backward()
            optimizer.step()
            if epoch % 50 == 0:
                # Average wall-clock cost of 50 epochs so far; stop if ten
                # such chunks (or at least 5 s) exceed the remaining budget.
                cost = (time.time() - st) / epoch * 50
                if max(cost * 10, 5) > self.timer.remain_time():
                    flag_end = True
                    break

        self.eval()
        with torch.no_grad():
            res = self.forward(data)
        if return_out:
            pred = res
        else:
            # ``test_mask`` is only needed (and only read) on this path.
            pred = res[data.test_mask]
        if val_mask is not None:
            return pred, res[val_mask], flag_end
        return pred, flag_end

    def __repr__(self):
        return self.__class__.__name__
class SAGE(nn.Module):
    """
    SAGE model class

    Attributes:
        NN_1 : SAGEConv
            Input graph convolution
        NN_2 : SAGEConv
            Hidden graph convolution
        NN_3 : SAGEConv
            Output graph convolution
        BN_1, BN_2 : torch.nn.BatchNorm1d
            Batch normalization after the first and second convolution
    """

    def __init__(self, num_in, num_hid, num_out, mathcal_Z, dropout, device):
        """
        Initialization

        Parameters:
            num_in : int
                Number of input features
            num_hid : int
                Number of hidden features
            num_out : int
                Number of output features
            mathcal_Z : torch.distributions.multivariate_normal or None
                AWGN channel; when None, dropout is applied instead of noise
            dropout : float
                Dropout probability (used only when mathcal_Z is None)
            device : torch.device
                Device the sampled noise is moved to

        Returns:

        """
        super(SAGE, self).__init__()

        # Graph convolution layers
        self.NN_1 = SAGEConv(num_in, num_hid)
        self.NN_2 = SAGEConv(num_hid, num_hid)
        self.NN_3 = SAGEConv(num_hid, num_out)

        # Batch Normalization layers
        self.BN_1 = nn.BatchNorm1d(num_hid)
        self.BN_2 = nn.BatchNorm1d(num_hid)

        self.mathcal_Z = mathcal_Z
        self.dropout = dropout
        self.device = device

    def reset_parameters(self):
        """
        Reset the parameters
        """
        self.NN_1.reset_parameters()
        self.NN_2.reset_parameters()
        self.NN_3.reset_parameters()
        self.BN_1.reset_parameters()
        self.BN_2.reset_parameters()

    def _through_channel(self, S):
        """
        Regularize a hidden activation before it enters the next layer

        Without a channel, dropout is applied. With a channel, one noise
        sample per row is added during training and the activation is passed
        through unchanged during evaluation.
        """
        if self.mathcal_Z is None:
            return F.dropout(S, p=self.dropout, training=self.training)
        if not self.training:
            # The noise is unused in evaluation mode, so it is not sampled.
            return S
        mathcal_z = self.mathcal_Z.sample(torch.Size([S.size(0)])).to(self.device)
        return S + mathcal_z

    def forward(self, x, edge_index):
        """
        Forward module of the SAGE model

        Parameters:
            x : torch.tensor of shape (num_examples, num_dims)
                Input tensor
            edge_index : torch.tensor of shape (2, num_edges)
                Input edge index

        Returns:
            z : torch.tensor
                Detached output of the last convolution (pre-softmax scores)
            y_pred : torch.tensor
                Log-softmax of z over the last dimension
            [S_1, S_2] : list of length 2
                Detached outputs of the hidden layers before passing through
                the AWGN channels
        """
        x = self.NN_1(x, edge_index)
        x = self.BN_1(x)
        S_1 = F.relu(x)
        # Previously the dropout branch wrote to ``x`` and left ``T_1``
        # undefined, so the model was unusable without a channel.
        T_1 = self._through_channel(S_1)

        x = self.NN_2(T_1, edge_index)
        x = self.BN_2(x)
        S_2 = F.relu(x)
        T_2 = self._through_channel(S_2)

        z = self.NN_3(T_2, edge_index)
        y_pred = torch.log_softmax(z, dim=-1)

        return z.detach(), y_pred, [S_1.detach(), S_2.detach()]
class NodeImportance(torch.nn.Module):
    """Top-k node selection from a GNN score adjusted by centrality features.

    The node score is computed by one to three ``GraphConv`` layers (``'GCN'``)
    or by a single ``GATConv`` / ``SAGEConv`` layer, scaled by a learnable
    weight, shifted by a learnable combination of the supplied ``closeness``
    and ``degree`` values (plus an optional bias), and passed through ReLU.

    Args:
        in_channels: Input feature width.
        ratio: Pooling ratio forwarded to ``topk``.
        layer: Number of stacked scoring layers (1, 2 or 3).  Only ``'GCN'``
            supports more than one; for ``'GAT'`` and ``'SAGE'`` a single
            layer is always built.
        gnn: One of ``'GCN'``, ``'GAT'`` or ``'SAGE'``.
        bias: Whether to learn an additive bias on the centrality term.
        **kwargs: Extra arguments for the scoring layer constructor.

    Raises:
        ValueError: If ``gnn == 'GCN'`` and ``layer`` is not 1, 2 or 3.
    """

    def __init__(self, in_channels, ratio=0.5, layer=1, gnn='GCN', bias=True,
                 **kwargs):
        super(NodeImportance, self).__init__()
        self.in_channels = in_channels
        self.ratio = ratio
        # Needed by ``__repr__``; it was never stored before, so printing
        # the module raised AttributeError.
        self.gnn_name = gnn
        assert gnn in ['GCN', 'GAT', 'SAGE']
        if gnn == 'GCN':
            if layer not in (1, 2, 3):
                # Such a module could be built before but always failed in
                # ``forward`` with an unbound ``score``.
                raise ValueError(
                    'layer must be 1, 2 or 3, got {!r}'.format(layer))
            self.layer = layer
            if layer == 1:
                self.gnn = GraphConv(self.in_channels, 1, **kwargs)
            elif layer == 2:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels, **kwargs)
                self.gnn2 = GraphConv(self.in_channels, 1, **kwargs)
            else:
                self.gnn1 = GraphConv(self.in_channels, self.in_channels, **kwargs)
                self.gnn2 = GraphConv(self.in_channels, self.in_channels, **kwargs)
                self.gnn3 = GraphConv(self.in_channels, 1, **kwargs)
        else:
            # Only one scoring layer exists for GAT / SAGE, so record depth 1
            # regardless of ``layer``; otherwise ``reset_parameters`` and
            # ``forward`` would look up the non-existent ``gnn1``.
            self.layer = 1
            if gnn == 'GAT':
                self.gnn = GATConv(self.in_channels, 1, **kwargs)
            else:
                self.gnn = SAGEConv(self.in_channels, 1, **kwargs)
        self.weight_closeness = Parameter(torch.Tensor(1))
        self.weight_degree = Parameter(torch.Tensor(1))
        self.weight_score = Parameter(torch.Tensor(1))
        if bias:
            self.bias = Parameter(torch.Tensor(1))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def _score_layers(self):
        """Return the scoring layers in application order."""
        if self.layer == 1:
            return [self.gnn]
        return [getattr(self, 'gnn{}'.format(i))
                for i in range(1, self.layer + 1)]

    def reset_parameters(self):
        """Re-initialise the scoring layers and draw the scalars from U(0, 1)."""
        for gnn in self._score_layers():
            gnn.reset_parameters()
        uniform_(self.weight_closeness, a=0, b=1)
        uniform_(self.weight_degree, a=0, b=1)
        # ``bias`` is None when the module was built with bias=False.
        if self.bias is not None:
            uniform_(self.bias, a=0, b=1)
        uniform_(self.weight_score, a=0, b=1)

    def forward(self, x, edge_index, closeness, degree, edge_attr=None,
                batch=None):
        """Select the top-scoring nodes.

        Returns:
            ``(x, perm, batch)``: the kept node features scaled by their
            score, the indices of the kept nodes, and their batch vector.
            ``edge_attr`` is accepted but not used.
        """
        if batch is None:
            batch = edge_index.new_zeros(x.size(0))
        x = x.unsqueeze(-1) if x.dim() == 1 else x

        layers = self._score_layers()
        score = x
        for gnn in layers[:-1]:
            score = torch.relu(gnn(score, edge_index))
        # The last layer has a single output channel; flatten to one score
        # per node.
        score = torch.relu(layers[-1](score, edge_index).view(-1))

        # Centrality adjustment.
        closeness = closeness * self.weight_closeness
        degree = degree * self.weight_degree
        centrality = closeness + degree
        if self.bias is not None:
            centrality += self.bias
        score = score * self.weight_score
        score = score + centrality
        score = F.relu(score)

        perm = topk(score, self.ratio, batch)
        x = x[perm] * score[perm].view(-1, 1)
        batch = batch[perm]
        return x, perm, batch

    def __repr__(self):
        return '{}({}, {}, ratio={})'.format(self.__class__.__name__,
                                             self.gnn_name,
                                             self.in_channels, self.ratio)
class GraphSAGE(torch.nn.Module):
    """SAGEConv stack with a combined add/mean/max readout and two outputs.

    Args:
        num_input_features: Input feature width per node.
        num_layers: Total number of SAGEConv layers (first layer included).
        hidden: Width of every hidden representation.
    """

    def __init__(self, num_input_features, num_layers, hidden):
        super().__init__()
        self.conv1 = SAGEConv(num_input_features, hidden)
        self.convs = torch.nn.ModuleList(
            [SAGEConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        # Three pooled tensors of width ``hidden`` are concatenated in
        # ``forward``, hence the 3 * hidden input width.
        self.lin1 = Linear(3 * hidden, hidden)
        # The classifier always emits two scores.
        self.lin2 = Linear(hidden, 2)

    def reset_parameters(self):
        """Re-initialise all convolutional and linear layers."""
        self.conv1.reset_parameters()
        for layer in self.convs:
            layer.reset_parameters()
        for head in (self.lin1, self.lin2):
            head.reset_parameters()

    def forward(self, data):
        """Return per-graph log-probabilities.

        ``data`` must expose ``x`` (node features), ``edge_index`` and
        ``batch`` (graph id of every node).
        """
        feats, edges, graph_ids = data.x, data.edge_index, data.batch

        # Node-level encoding; every convolution outputs ``hidden`` channels.
        h = F.relu(self.conv1(feats, edges))
        for layer in self.convs:
            h = F.relu(layer(h, edges))

        # Graph-level readout: sum, mean and max pooling side by side,
        # giving 3 * hidden columns per graph.
        readouts = [
            global_add_pool(h, graph_ids),
            global_mean_pool(h, graph_ids),
            global_max_pool(h, graph_ids),
        ]
        pooled = torch.cat(readouts, dim=1)

        # MLP head: 3 * hidden -> hidden -> 2.
        pooled = F.relu(self.lin1(pooled))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        scores = self.lin2(pooled)
        return F.log_softmax(scores, dim=-1)

    def __repr__(self):
        return type(self).__name__