class GCN(torch.nn.Module):
    """GCN layers, a global mean pool and a two-layer linear head.

    Args:
        num_features: Input width passed to the first ``GCNConv``.
        output_channels: Output width of the last ``Linear`` layer.
        num_layers: Total number of ``GCNConv`` layers; ``num_layers - 1``
            of them live in ``self.convs``.
        nb_neurons: Width of every hidden layer.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, num_features, output_channels, num_layers=3,
                 nb_neurons=128, **kwargs):
        super().__init__()
        self.conv1 = GCNConv(num_features, nb_neurons)
        self.convs = torch.nn.ModuleList(
            [GCNConv(nb_neurons, nb_neurons) for _ in range(num_layers - 1)]
        )
        self.lin1 = Linear(nb_neurons, nb_neurons)
        self.lin2 = Linear(nb_neurons, output_channels)

    def reset_parameters(self):
        """Re-initialise every layer in construction order."""
        for layer in (self.conv1, *self.convs, self.lin1, self.lin2):
            layer.reset_parameters()

    def forward(self, data, target_size, **kwargs):
        """Return the log-softmax of the head output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
            target_size: Forwarded as ``size`` to ``global_mean_pool``.
            **kwargs: Accepted and ignored.
        """
        h, edge_index, batch = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(h, edge_index))
        for layer in self.convs:
            h = F.relu(layer(h, edge_index))
        pooled = global_mean_pool(h, batch, size=target_size)
        pooled = F.relu(self.lin1(pooled))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        logits = self.lin2(pooled)
        return F.log_softmax(logits, dim=-1)

    def __repr__(self):
        return type(self).__name__
class GCN(torch.nn.Module):
    """GCN layers, a global mean pool and a two-layer linear head.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of ``GCNConv`` layers; ``num_layers - 1``
            of them live in ``self.convs``.
        hidden: Width of every hidden layer.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every layer in construction order."""
        for layer in (self.conv1, *self.convs, self.lin1, self.lin2):
            layer.reset_parameters()

    def forward(self, data):
        """Return the log-softmax of the head output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        h, edge_index, batch = data.x, data.edge_index, data.batch
        h = F.relu(self.conv1(h, edge_index))
        for layer in self.convs:
            h = F.relu(layer(h, edge_index))
        pooled = global_mean_pool(h, batch)
        pooled = F.relu(self.lin1(pooled))
        pooled = F.dropout(pooled, p=0.5, training=self.training)
        logits = self.lin2(pooled)
        return F.log_softmax(logits, dim=-1)

    def __repr__(self):
        return type(self).__name__
class Net(torch.nn.Module):
    """Two-layer GCN encoder that scores candidate edges by a dot product.

    ``score1``..``score3`` are constructed but are neither used by
    ``forward`` nor touched by ``reset_parameters``. ``args`` is read from
    the enclosing module scope.

    Args:
        dataset: Object exposing ``num_features``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 128)
        self.conv2 = GCNConv(128, 64)
        self.score1 = GATScore(Linear(dataset.num_features * 2, 1))
        self.score2 = GATScore(Linear(args.hidden * 2, 1))
        self.score3 = GATScore(Linear(args.hidden * 2, 1))

    def reset_parameters(self):
        """Re-initialise the two convolution layers."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, data, pos_edge_index, neg_edge_index):
        """Return one score per candidate edge.

        Args:
            data: Object exposing ``x`` and ``train_edge_index``.
            pos_edge_index: Edge index concatenated first along the last
                dimension.
            neg_edge_index: Edge index concatenated second.

        Returns:
            The feature-wise product of both endpoint embeddings, summed
            over features, for every column of the concatenated index.
        """
        h, edge_index = data.x, data.train_edge_index
        h = F.relu(self.conv1(h, edge_index))
        h = self.conv2(h, edge_index)

        candidates = torch.cat([pos_edge_index, neg_edge_index], dim=-1)
        emb_row0 = h.index_select(0, candidates[0])
        emb_row1 = h.index_select(0, candidates[1])
        return torch.einsum("ef,ef->e", emb_row1, emb_row0)
def reset_parameters(self):
    """Reset every down-conv, then every pool, then every up-conv."""
    for group in (self.down_convs, self.pools, self.up_convs):
        for module in group:
            module.reset_parameters()
class GCNRecsysModel(GraphRecsysModel):
    """Two-layer GCN that produces L2-normalised node representations.

    All configuration arrives through keyword arguments; ``_init`` reads
    ``if_use_features``, ``dropout``, ``dataset``, ``emb_dim``,
    ``hidden_size`` and ``repr_dim``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def _init(self, **kwargs):
        """Create the embedding table, graph input and both GCN layers.

        Raises:
            NotImplementedError: If ``if_use_features`` is truthy.
        """
        self.if_use_features = kwargs['if_use_features']
        self.dropout = kwargs['dropout']

        if self.if_use_features:
            raise NotImplementedError('Feature not implemented!')

        embedding = torch.nn.Embedding(
            kwargs['dataset']['num_nodes'], kwargs['emb_dim'], max_norm=1
        )
        self.x = embedding.weight
        self.edge_index = self.update_graph_input(kwargs['dataset'])

        self.conv1 = GCNConv(kwargs['emb_dim'], kwargs['hidden_size'])
        self.conv2 = GCNConv(kwargs['hidden_size'], kwargs['repr_dim'])

    def reset_parameters(self):
        """Redraw the embeddings from U(-1, 1) and reset both GCN layers."""
        if not self.if_use_features:
            torch.nn.init.uniform_(self.x, -1.0, 1.0)
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Encode ``x`` over ``edge_index`` and normalise the result."""
        hidden = F.relu(self.conv1(x, edge_index))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        representation = self.conv2(hidden, edge_index)
        return F.normalize(representation)
class GCN(torch.nn.Module):  # already fine-tuned
    """Linear projection plus two GCN layers whose outputs are concatenated.

    The classifier ``out`` consumes the projected input and the output of
    both convolutions side by side, hence its ``hidden * 3`` input width.
    ``fuse_weight`` is a learnable vector that ``forward`` does not use.

    Args:
        num_layers: Length of ``fuse_weight``; the number of convolutions is
            fixed at two regardless of this value.
        hidden: Width of the projection and of both convolutions.
        features_num: Input feature width.
        num_class: Number of output classes.
    """

    def __init__(self, num_layers=2, hidden=32, features_num=32, num_class=2):
        super(GCN, self).__init__()
        self.conv1 = GCNConv(hidden, hidden)
        self.conv2 = GCNConv(hidden, hidden)
        self.out = Linear(hidden * 3, num_class)
        self.first_lin = Linear(features_num, hidden)
        self.fuse_weight = torch.nn.Parameter(
            torch.full((num_layers,), 1.0 / (num_layers + 1), dtype=torch.float),
            requires_grad=True,
        )

    def reset_parameters(self):
        """Re-initialise every layer this class actually owns.

        The previous version referenced ``self.convs`` and ``self.lin2``,
        which do not exist here, and therefore raised ``AttributeError``.
        """
        self.first_lin.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.out.reset_parameters()
        # Same constant the constructor uses: 1 / (len(fuse_weight) + 1).
        self.fuse_weight.data.fill_(1.0 / (self.fuse_weight.numel() + 1))

    def forward(self, data):
        """Return the log-softmax of the classifier output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_weight``.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.relu(self.first_lin(x))
        x = F.dropout(x, p=0.5, training=self.training)
        xx = x
        x = self.conv1(x, edge_index, edge_weight)
        x = F.dropout(x, p=0.2, training=self.training)
        xx = torch.cat([xx, x], dim=1)
        x = self.conv2(x, edge_index, edge_weight)
        x = F.dropout(x, p=0.2, training=self.training)
        xx = torch.cat([xx, x], dim=1)
        x = self.out(xx)
        return F.log_softmax(x, dim=-1)
class Net(torch.nn.Module):
    """Two-layer GCN with two auxiliary attention vectors.

    ``att1`` and ``att2`` are registered parameters that ``forward`` does
    not use. ``args`` is read from the enclosing module scope.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)
        # torch.Tensor(n) returns uninitialised memory (possibly NaN/inf);
        # start from a defined value instead.
        # NOTE(review): ones is a neutral element-wise scale -- confirm the
        # intended initialisation once these vectors are put to use.
        self.att1 = Parameter(torch.ones(args.hidden))
        self.att2 = Parameter(torch.ones(dataset.num_classes))

    def reset_parameters(self):
        """Reset both convolutions and restore the attention vectors to ones."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        torch.nn.init.ones_(self.att1)
        torch.nn.init.ones_(self.att2)

    def forward(self, data, pos_edge_index, neg_edge_index, edge_index,
                masked_nodes):
        """Return the output of the second convolution.

        Args:
            data: Object exposing ``x``.
            pos_edge_index: Unused; kept for call compatibility.
            neg_edge_index: Unused; kept for call compatibility.
            edge_index: Edge index both convolutions run on.
            masked_nodes: Unused; kept for call compatibility.
        """
        x = F.relu(self.conv1(data.x, edge_index))  # LAYER 1
        return self.conv2(x, edge_index)  # LAYER 2
class GCN(torch.nn.Module):
    """GCN layers followed by a two-layer linear head, without pooling.

    ``forward`` returns the raw output of the last linear layer; no softmax
    is applied.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Total number of ``GCNConv`` layers; ``num_layers - 1``
            of them live in ``self.convs``.
        hidden: Width of every hidden layer.
    """

    def __init__(self, dataset, num_layers, hidden):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        self.lin1 = Linear(hidden, hidden)
        self.lin2 = Linear(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise every layer in construction order."""
        for layer in (self.conv1, *self.convs, self.lin1, self.lin2):
            layer.reset_parameters()

    def forward(self, dataX, dataY):
        """Run the network.

        Args:
            dataX: Passed to the first convolution as its feature input.
            dataY: Passed to every convolution as its edge index.
        """
        h = F.relu(self.conv1(dataX, dataY))
        for layer in self.convs:
            h = F.relu(layer(h, dataY))
        h = F.relu(self.lin1(h))
        h = F.dropout(h, p=0.5, training=self.training)
        return self.lin2(h)

    def __repr__(self):
        return type(self).__name__
class GCN(torch.nn.Module):
    """Linear projection, a stack of weighted GCN layers and a linear classifier.

    ``conv1`` is constructed and reset but ``forward`` never calls it; the
    input goes through ``first_lin`` and then only through ``self.convs``.

    Args:
        features_num: Input feature width.
        num_class: Number of output classes.
        num_layers: ``num_layers - 1`` convolutions are placed in
            ``self.convs``.
        hidden: Width of every hidden layer.
    """

    def __init__(self, features_num, num_class, num_layers, hidden):
        super().__init__()
        self.conv1 = GCNConv(features_num, hidden)
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        self.lin2 = Linear(hidden, num_class)
        self.first_lin = Linear(features_num, hidden)

    def reset_parameters(self):
        """Re-initialise ``first_lin``, ``conv1``, ``convs`` and ``lin2``."""
        for layer in (self.first_lin, self.conv1, *self.convs, self.lin2):
            layer.reset_parameters()

    def forward(self, data):
        """Return the log-softmax of the classifier output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_weight``.
        """
        h = data.x
        edge_index = data.edge_index
        edge_weight = data.edge_weight

        h = F.relu(self.first_lin(h))
        h = F.dropout(h, p=0.5, training=self.training)
        for layer in self.convs:
            h = F.relu(layer(h, edge_index, edge_weight=edge_weight))
        h = F.dropout(h, p=0.5, training=self.training)
        logits = self.lin2(h)
        return F.log_softmax(logits, dim=-1)

    def __repr__(self):
        return type(self).__name__
class Net(torch.nn.Module):
    """Two-layer GCN whose masked rows are replaced by a learned mixture.

    ``args`` is read from the enclosing module scope.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``, and
            whose first element exposes ``num_nodes``.
    """

    def __init__(self, dataset):
        super(Net, self).__init__()
        num_nodes = dataset[0].num_nodes
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)
        # test_mask x all_node * all_node x feature
        self.p = torch.nn.Parameter(
            torch.randn(int(num_nodes * 0.2), num_nodes), requires_grad=True
        )

    def reset_parameters(self):
        """Reset both convolutions and redraw ``p`` from a standard normal.

        A ``Parameter`` has no ``reset_parameters`` method, so ``p`` is
        re-initialised through ``torch.nn.init`` to match the constructor.
        """
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        torch.nn.init.normal_(self.p)

    def forward(self, data, pos_edge_index, neg_edge_index):
        """Return the row-wise log-softmax of the node outputs.

        Args:
            data: Object exposing ``x``, ``train_edge_index`` and
                ``masked_nodes``.
            pos_edge_index: Unused; kept for call compatibility.
            neg_edge_index: Unused; kept for call compatibility.

        Note:
            Presumably ``masked_nodes`` selects as many rows as ``p`` has
            -- TODO confirm against the caller.
        """
        x, edge_index, masked_nodes = data.x, data.train_edge_index, data.masked_nodes
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        # Write into a copy: matmul keeps ``x`` for its backward pass, so
        # overwriting ``x`` in place would invalidate the gradient of ``p``.
        out = x.clone()
        out[masked_nodes] = torch.matmul(self.p, x)
        return F.log_softmax(out, dim=1)
class VariationalGraphDecoder(nn.Module):
    """Acts on NxD node embedding matrix.

    Projects the input with one ``GCNConv`` and then walks the stored
    encoder levels from last to first, scattering the current features back
    into each level and combining them with that level's residual.

    Args:
        in_channels: Input width of the projection convolution.
        hidden_channels: Output width of the projection and of every
            up-convolution except the last.
        out_channels: Output width of the last up-convolution.
        depth: Number of up-sampling steps; must be at least 1.
        sum_res: If true the residual is added, otherwise concatenated.
        act: Activation applied after every up-convolution but the last.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        out_channels: int,
        depth: int = 1,
        sum_res: bool = True,
        act=F.relu,
    ):
        super(VariationalGraphDecoder, self).__init__()
        assert depth >= 1
        self.depth = depth
        self.sum_res = sum_res
        self.act = act

        self.projection_conv = GCNConv(in_channels, hidden_channels, improved=True)

        # ``up_sample`` feeds each up-conv either ``res + up`` or
        # ``cat(res, up)``. ``up`` is ``zeros_like(res)``, so concatenation
        # doubles the width and the up-convs must be sized accordingly;
        # they were previously always built for ``hidden_channels`` inputs.
        up_in_channels = hidden_channels if sum_res else 2 * hidden_channels

        self.up_convs = nn.ModuleList()
        for _ in range(depth - 1):
            self.up_convs.append(
                GCNConv(up_in_channels, hidden_channels, improved=True)
            )
        self.up_convs.append(GCNConv(up_in_channels, out_channels, improved=True))

        self.reset_parameters()

    def reset_parameters(self):
        """Reset the projection and every up-convolution."""
        self.projection_conv.reset_parameters()
        for conv in self.up_convs:
            conv.reset_parameters()

    def forward(self, x, edge_index, xs, edge_indices, edge_weights, perms):
        """Project ``x`` and up-sample it through the stored levels.

        Args:
            x: Input to the projection convolution.
            edge_index: Edge index used by the projection convolution.
            xs: Per-level residuals, indexed by level.
            edge_indices: Per-level edge indices.
            edge_weights: Per-level edge weights.
            perms: Per-level row indices that the current features are
                scattered to.
        """
        x = self.projection_conv(x, edge_index)
        x = self.up_sample(x, xs, edge_indices, edge_weights, perms)
        return x

    def up_sample(self, x, xs, edge_indices, edge_weights, perms):
        """Scatter ``x`` back level by level, from index ``depth - 1`` to 0."""
        for i in range(self.depth):
            j = self.depth - 1 - i

            res = xs[j]
            edge_index = edge_indices[j]
            edge_weight = edge_weights[j]
            perm = perms[j]

            # Rows not listed in ``perm`` stay zero.
            up = torch.zeros_like(res)
            up[perm] = x
            x = res + up if self.sum_res else torch.cat((res, up), dim=-1)

            x = self.up_convs[i](x, edge_index, edge_weight)
            # No activation after the final up-convolution.
            x = self.act(x) if i < self.depth - 1 else x

        return x
class Net(torch.nn.Module):
    """Two-layer GCN with one linear edge scorer per representation level.

    Edge differences are taken on the raw features, after the first
    convolution and after the second one. ``args`` is read from the
    enclosing module scope. ``sum`` is constructed but not used by
    ``forward``.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)
        self.score1 = Linear(dataset.num_features, 1)
        self.score2 = Linear(args.hidden, 1)
        self.score3 = Linear(dataset.num_classes, 1)
        self.sum = Linear(3, 1)

    def reset_parameters(self):
        """Re-initialise every layer, including ``sum`` (previously skipped)."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.score1.reset_parameters()
        self.score2.reset_parameters()
        self.score3.reset_parameters()
        self.sum.reset_parameters()

    @staticmethod
    def _edge_diff(x, edge_index):
        """Return ``x[edge_index[0]] - x[edge_index[1]]`` for every column."""
        x_j = torch.index_select(x, 0, edge_index[0])
        x_i = torch.index_select(x, 0, edge_index[1])
        return x_j - x_i

    def forward(self, data, pos_edge_index, neg_edge_index, edge_index):
        """Score candidate edges and compute the auxiliary score loss.

        Args:
            data: Object exposing ``x``.
            pos_edge_index: Edge index concatenated first along the last
                dimension.
            neg_edge_index: Edge index concatenated second.
            edge_index: Edge index both convolutions run on.

        Returns:
            A tuple ``(o3, score_loss)``: the ReLU of ``score3`` applied to
            the final edge differences (one value per candidate edge), and
            the sum over the three levels of the mean projection of the
            edge differences onto the matching scorer weight.
        """
        x = data.x
        total_edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)

        dist1 = self._edge_diff(x, total_edge_index)

        x = F.relu(self.conv1(x, edge_index))
        dist2 = self._edge_diff(x, total_edge_index)

        # 2 layer
        x = self.conv2(x, edge_index)
        dist3 = self._edge_diff(x, total_edge_index)

        # squeeze(-1) drops only the scorer's output dimension, so a single
        # candidate edge still yields a 1-D result.
        o3 = F.relu(self.score3(dist3).squeeze(-1))

        # squeeze(0) turns the (1, in_features) weight into a vector even
        # when in_features == 1.
        score_loss = (
            torch.matmul(dist1, self.score1.weight.squeeze(0)).mean()
            + torch.matmul(dist2, self.score2.weight.squeeze(0)).mean()
            + torch.matmul(dist3, self.score3.weight.squeeze(0)).mean()
        )
        return o3, score_loss
class LinearEncoder(nn.Module):
    """Encoder made of a single ``GCNConv`` with no activation.

    Args:
        in_channels: Input width of the convolution.
        out_channels: Output width of the convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = GCNConv(in_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise the convolution."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return the convolution of ``x`` over ``edge_index``."""
        encoded = self.conv(x, edge_index)
        return encoded
class CLS(torch.nn.Module):
    """Classification layer: one uncached ``GCNConv`` followed by log-softmax.

    Args:
        d_in: Input width of the convolution.
        d_out: Output width of the convolution.
    """

    def __init__(self, d_in, d_out):
        super().__init__()
        self.conv = GCNConv(d_in, d_out, cached=False)

    def reset_parameters(self):
        """Re-initialise the convolution."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=None):
        """Return the log-softmax over dimension 1; ``mask`` is ignored."""
        logits = self.conv(x, edge_index)
        return F.log_softmax(logits, dim=1)
class GCNEncoder(nn.Module):
    """Two-layer GCN encoder whose hidden width is ``2 * out_channels``.

    Args:
        in_channels: Input width of the first convolution.
        out_channels: Output width of the second convolution.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        hidden_channels = 2 * out_channels
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, out_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Apply conv1, ReLU, then conv2."""
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)
class CRD(torch.nn.Module):
    """Convolution, ReLU and dropout in one block.

    Args:
        d_in: Input width of the cached ``GCNConv``.
        d_out: Output width of the cached ``GCNConv``.
        p: Dropout probability.
    """

    def __init__(self, d_in, d_out, p):
        super().__init__()
        self.conv = GCNConv(d_in, d_out, cached=True)
        self.p = p

    def reset_parameters(self):
        """Re-initialise the convolution."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=False):
        """Return dropout(relu(conv(x))); ``mask`` is ignored."""
        activated = F.relu(self.conv(x, edge_index))
        return F.dropout(activated, p=self.p, training=self.training)
class GCNEncoder(torch.nn.Module):
    """Single ``GCNConv`` followed by a per-channel ``PReLU``.

    Args:
        in_dim: Input width of the convolution.
        out_dim: Output width of the convolution and number of PReLU
            parameters.
    """

    def __init__(self, in_dim, out_dim):
        super(GCNEncoder, self).__init__()
        self.conv = GCNConv(in_dim, out_dim)
        self.sigma = nn.PReLU(out_dim)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution; ``sigma`` is left untouched."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index):
        """Return the activated convolution output."""
        convolved = self.conv(x, edge_index)
        return self.sigma(convolved)
class Discriminator(torch.nn.Module):
    """Two-layer GCN with ``args.hidden`` units in both layers.

    ``args`` is read from the enclosing module scope.

    Args:
        dataset: Object exposing ``num_features``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, args.hidden)

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Apply conv1 with ReLU, then conv2 without activation."""
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
class Net(torch.nn.Module):
    """Two-layer GCN classifier taking the edge index as a separate argument.

    ``hidden`` is read from the enclosing module scope.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.conv2 = GCNConv(hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, data, edge_index):
        """Return the log-softmax over dimension 1 of the conv2 output.

        Args:
            data: Object exposing ``x``.
            edge_index: Edge index both convolutions run on.
        """
        activated = F.relu(self.conv1(data.x, edge_index))
        logits = self.conv2(activated, edge_index)
        return F.log_softmax(logits, dim=1)
class PyGGCN(torch.nn.Module):
    """Two-layer GCN returning the raw output of the second convolution.

    Args:
        input_size: Input width of the first convolution.
        num_class: Output width of the second convolution.
        hidden_size: Width between the two convolutions.
    """

    def __init__(self, input_size, num_class=1, hidden_size=64):
        super().__init__()
        self.conv1 = GCNConv(input_size, hidden_size)
        self.conv2 = GCNConv(hidden_size, num_class)

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, x, edges):
        """Run both convolutions on the transpose of ``edges``."""
        edge_index = edges.T
        hidden = F.relu(self.conv1(x, edge_index))
        return self.conv2(hidden, edge_index)
class GCNNet(torch.nn.Module):
    """Two-layer GCN with dropout before each convolution and ELU in between.

    Args:
        input_size: Input width of the first convolution.
        output_size: Output width of the second convolution.
        hidden_size: Width between the two convolutions.
    """

    def __init__(self, input_size, output_size, hidden_size=512):
        super().__init__()
        self.conv1 = GCNConv(input_size, hidden_size)
        self.conv2 = GCNConv(hidden_size, output_size)

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, feature, edge_index):
        """Return the raw output of the second convolution."""
        dropped = F.dropout(feature, p=0.5, training=self.training)
        hidden = F.elu(self.conv1(dropped, edge_index))
        hidden = F.dropout(hidden, p=0.5, training=self.training)
        return self.conv2(hidden, edge_index)
class Net(torch.nn.Module):
    """Two-layer GCN classifier configured through the module-level ``args``.

    ``args.hidden`` sets the hidden width and ``args.dropout`` the dropout
    probability.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, data):
        """Return the log-softmax over dimension 1 of the conv2 output.

        Args:
            data: Object exposing ``x`` and ``edge_index``.
        """
        features, edge_index = data.x, data.edge_index
        hidden = F.relu(self.conv1(features, edge_index))
        hidden = F.dropout(hidden, p=args.dropout, training=self.training)
        logits = self.conv2(hidden, edge_index)
        return F.log_softmax(logits, dim=1)
class Net(torch.nn.Module):
    """Single GCN layer followed by ReLU and a square linear layer.

    ``conv1`` is constructed and reset but ``forward`` does not call it;
    ``conv2`` consumes the raw input features directly. ``hidden`` is read
    from the enclosing module scope.

    Args:
        dataset: Object exposing ``num_features`` and whose first element
            exposes ``num_class``.
    """

    def __init__(self, dataset):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.conv2 = GCNConv(dataset.num_features, int(dataset[0].num_class))
        lin_in = int(dataset[0].num_class)
        lin_out = int(dataset[0].num_class)
        self.lin = Linear(lin_in, lin_out)

    def reset_parameters(self):
        """Re-initialise ``conv1``, ``conv2`` and ``lin``."""
        for layer in (self.conv1, self.conv2, self.lin):
            layer.reset_parameters()

    def forward(self, data, edge_index):
        """Return ``lin(relu(conv2(data.x, edge_index)))``.

        Args:
            data: Object exposing ``x``.
            edge_index: Edge index the convolution runs on.
        """
        convolved = self.conv2(data.x, edge_index)
        activated = F.relu(convolved)
        return self.lin(activated)
class Net(torch.nn.Module):
    """Two-layer GCN bound to the module-level ``data`` object.

    Both the layer sizes and the forward inputs are taken from ``data``,
    which must exist in the enclosing module scope.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(data.num_features, 16)
        self.conv2 = GCNConv(16, data.num_classes)

    def forward(self):
        """Return the log-softmax over dimension 1 of the conv2 output."""
        hidden = F.relu(self.conv1(data.x, data.edge_index))
        hidden = F.dropout(hidden, training=self.training)
        logits = self.conv2(hidden, data.edge_index)
        return F.log_softmax(logits, dim=1)

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()
class GCN(torch.nn.Module):
    """Linear projection, a stack of weighted GCN layers and a linear classifier.

    ``conv1`` is constructed and reset but ``forward`` never calls it; the
    input goes through ``first_lin`` and then only through ``self.convs``.

    Args:
        num_layers: ``num_layers - 1`` convolutions are placed in
            ``self.convs``.
        hidden: Width of every hidden layer.
        features_num: Input feature width.
        num_class: Number of output classes.
    """

    def __init__(self, num_layers=2, hidden=16, features_num=16, num_class=2):
        super().__init__()
        # Input-width convolution (not used by forward).
        self.conv1 = GCNConv(features_num, hidden)
        # Hidden-width convolutions.
        self.convs = torch.nn.ModuleList(
            [GCNConv(hidden, hidden) for _ in range(num_layers - 1)]
        )
        # Output classifier and input projection.
        self.lin2 = Linear(hidden, num_class)
        self.first_lin = Linear(features_num, hidden)

    def reset_parameters(self):
        """Re-initialise ``first_lin``, ``conv1``, ``convs`` and ``lin2``."""
        for layer in (self.first_lin, self.conv1, *self.convs, self.lin2):
            layer.reset_parameters()

    def forward(self, data):
        """Return the log-softmax of the classifier output.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``edge_weight``.
        """
        h = data.x
        edge_index = data.edge_index
        edge_weight = data.edge_weight

        # Projection, ReLU and dropout.
        h = F.relu(self.first_lin(h))
        h = F.dropout(h, p=0.5, training=self.training)
        # Weighted graph convolutions.
        for layer in self.convs:
            h = F.relu(layer(h, edge_index, edge_weight=edge_weight))
        # Dropout, classifier and log-softmax.
        h = F.dropout(h, p=0.5, training=self.training)
        logits = self.lin2(h)
        return F.log_softmax(logits, dim=-1)

    def __repr__(self):
        return type(self).__name__
class GCNNet(nn.Module):
    """Two-layer GCN with a 16-unit hidden layer and an overridable dropout mode.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super(GCNNet, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both convolutions."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, x, edge_index, training=None):
        """Return the raw output of the second convolution.

        Args:
            x: Feature input of the first convolution.
            edge_index: Edge index both convolutions run on.
            training: Dropout mode. ``None`` (the default) follows
                ``self.training``; any other value is passed to
                ``F.dropout`` as is.
        """
        x = F.relu(self.conv1(x, edge_index))
        # Identity test: ``== None`` would dispatch to the argument's
        # ``__eq__`` instead of checking for the sentinel itself.
        if training is None:
            training = self.training
        x = F.dropout(x, p=0.5, training=training)
        x = self.conv2(x, edge_index)
        return x
class Net(torch.nn.Module):
    """Two-layer GCN with a linear scorer on edge-wise output differences.

    ``args`` is read from the enclosing module scope.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
    """

    def __init__(self, dataset):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_features, args.hidden)
        self.conv2 = GCNConv(args.hidden, dataset.num_classes)
        self.score = Linear(dataset.num_classes, 1)

    def reset_parameters(self):
        """Re-initialise both convolutions and ``score`` (previously skipped)."""
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()
        self.score.reset_parameters()

    def forward(self, data, pos_edge_index, neg_edge_index):
        """Score candidate edges and classify nodes.

        Args:
            data: Object exposing ``x`` and ``train_edge_index``.
            pos_edge_index: Edge index concatenated first along the last
                dimension.
            neg_edge_index: Edge index concatenated second.

        Returns:
            A tuple ``(out, score_loss, log_probs)``: the ReLU of the scorer
            on each edge difference, the mean projection of the edge
            differences onto the scorer weight, and the log-softmax over
            dimension 1 of the second convolution's output.
        """
        x, edge_index = data.x, data.train_edge_index
        total_edge_index = torch.cat([pos_edge_index, neg_edge_index], dim=-1)

        x = F.relu(self.conv1(x, edge_index))
        # 2 layer
        x = self.conv2(x, edge_index)

        x_j = torch.index_select(x, 0, total_edge_index[0])
        x_i = torch.index_select(x, 0, total_edge_index[1])
        dist = x_j - x_i

        # squeeze(-1) drops only the scorer's output dimension, so a single
        # candidate edge still yields a 1-D result.
        out = F.relu(self.score(dist).squeeze(-1))
        # squeeze(0) turns the (1, in_features) weight into a vector even
        # when in_features == 1.
        score_loss = torch.matmul(dist, self.score.weight.squeeze(0)).mean()
        return out, score_loss, F.log_softmax(x, dim=1)
class GCNII(torch.nn.Module):
    """Input ``GCNConv``, a stack of ``GCN2Conv`` layers and an output ``GCNConv``.

    Every convolution is built with ``normalize=False``.

    Args:
        in_channels: Input width of ``conv_in``.
        hidden_channels: Width of ``conv_in``'s output and of every
            ``GCN2Conv``.
        out_channels: Output width of ``conv_out``.
        num_layers: Number of ``GCN2Conv`` layers; they receive
            ``layer=1 .. num_layers``.
        dropout: Dropout probability applied after every ``GCN2Conv``.
        alpha: Forwarded to every ``GCN2Conv``.
        theta: Forwarded to every ``GCN2Conv``.
        shared_weights: Forwarded to every ``GCN2Conv``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout, alpha=0.5, theta=1.0, shared_weights=True):
        super().__init__()
        self.conv_in = GCNConv(in_channels, hidden_channels, normalize=False)
        self.convs = torch.nn.ModuleList([
            GCN2Conv(hidden_channels, alpha, theta, layer=depth,
                     shared_weights=shared_weights, normalize=False)
            for depth in range(1, num_layers + 1)
        ])
        self.conv_out = GCNConv(hidden_channels, out_channels, normalize=False)
        self.dropout = dropout

    def reset_parameters(self):
        """Reset ``conv_in``, then ``conv_out``, then every ``GCN2Conv``."""
        for layer in (self.conv_in, self.conv_out, *self.convs):
            layer.reset_parameters()

    def forward(self, x, adj_t):
        """Return the raw output of ``conv_out``.

        The activated output of ``conv_in`` is handed to every ``GCN2Conv``
        as its second argument.
        """
        initial = F.relu(self.conv_in(x, adj_t))
        hidden = initial
        for layer in self.convs:
            hidden = F.relu(layer(hidden, initial, adj_t))
            hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv_out(hidden, adj_t)
class ASAP_Pool(torch.nn.Module):
    """GCN layers interleaved with ``ASAP_Pooling`` and a summed readout.

    Args:
        dataset: Object exposing ``num_features`` and ``num_classes``.
        num_layers: Number of convolution/pooling pairs, counting the first.
        hidden: Width of every hidden layer.
        ratio: Pooling ratio. A scalar is used for every pooling layer; a
            list or tuple is indexed per layer.
        **kwargs: Forwarded to every ``ASAP_Pooling``.
    """

    def __init__(self, dataset, num_layers, hidden, ratio=0.8, **kwargs):
        super(ASAP_Pool, self).__init__()
        # isinstance also accepts tuples and list subclasses, which the
        # former ``type(ratio) != list`` test duplicated into a list of
        # sequences.
        if not isinstance(ratio, (list, tuple)):
            ratio = [ratio for _ in range(num_layers)]
        self.conv1 = GCNConv(dataset.num_features, hidden)
        self.pool1 = ASAP_Pooling(in_channels=hidden, ratio=ratio[0], **kwargs)
        self.convs = torch.nn.ModuleList()
        self.pools = torch.nn.ModuleList()
        for i in range(num_layers - 1):
            self.convs.append(GCNConv(hidden, hidden))
            # NOTE(review): ``ratio[0]`` is used by both ``pool1`` and the
            # first pool created here, and the last entry of a
            # ``num_layers``-long sequence is never read -- confirm whether
            # ``ratio[i + 1]`` was intended.
            self.pools.append(ASAP_Pooling(in_channels=hidden, ratio=ratio[i], **kwargs))
        self.lin1 = Linear(2 * hidden, hidden)  # 2*hidden due to readout layer
        self.lin2 = Linear(hidden, dataset.num_classes)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every convolution, pooling and linear layer."""
        self.conv1.reset_parameters()
        self.pool1.reset_parameters()
        for conv, pool in zip(self.convs, self.pools):
            conv.reset_parameters()
            pool.reset_parameters()
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()

    def forward(self, data):
        """Return the log-softmax of the classifier output.

        The readout of every pooling level is accumulated by addition
        before the linear head.

        Args:
            data: Object exposing ``x``, ``edge_index`` and ``batch``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.conv1(x, edge_index))
        x, edge_index, edge_weight, batch, perm = self.pool1(
            x=x, edge_index=edge_index, edge_weight=None, batch=batch)
        xs = readout(x, batch)
        for conv, pool in zip(self.convs, self.pools):
            x = F.relu(conv(x=x, edge_index=edge_index, edge_weight=edge_weight))
            x, edge_index, edge_weight, batch, perm = pool(
                x=x, edge_index=edge_index, edge_weight=edge_weight, batch=batch)
            xs += readout(x, batch)
        x = F.relu(self.lin1(xs))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin2(x)
        out = F.log_softmax(x, dim=-1)
        return out

    def __repr__(self):
        return self.__class__.__name__
class GNN_Block(torch.nn.Module):
    """Two GCN layers whose activated outputs are concatenated and projected.

    Args:
        in_channels: Input width of the first convolution.
        hidden_channels: Width of both convolutions and of the block output.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(2 * hidden_channels, hidden_channels)

    def reset_parameters(self):
        """Re-initialise both convolutions and the projection."""
        for layer in (self.conv1, self.conv2, self.lin):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        """Return ``lin`` applied to both layer outputs joined on the last dim."""
        first = F.relu(self.conv1(x, edge_index))
        second = F.relu(self.conv2(first, edge_index))
        joined = torch.cat([first, second], dim=-1)
        return self.lin(joined)