class APPNPNet(torch.nn.Module):
    """MLP feature transform followed by APPNP propagation.

    Node features go through a stack of exactly ``mlp_layers`` Linear layers.
    ReLU is applied between them but not after the last one. The per-node
    outputs are then propagated with ``APPNP``.

    Args:
        in_channels: Size of each input feature vector.
        hidden_channels: Width of the intermediate Linear layers (unused when
            ``mlp_layers == 1``).
        out_channels: Size of each output vector.
        num_iteration: Passed as APPNP's first positional argument
            (presumably the number of propagation steps K).
        mlp_layers: Total number of Linear layers in the MLP; must be >= 1.
        dropout: Forwarded to APPNP's ``dropout`` keyword. No dropout is
            applied inside the MLP itself. NOTE(review): in PyTorch Geometric's
            APPNP this is edge dropout during propagation — confirm which APPNP
            is imported.
        alpha: Passed as APPNP's second positional argument (presumably the
            teleport probability).

    Raises:
        ValueError: if ``mlp_layers < 1``.
    """

    def __init__(self, in_channels, hidden_channels, out_channels,
                 num_iteration, mlp_layers, dropout, alpha=0.1):
        super().__init__()
        # Built before ``appnp`` so parameter-initialisation order is unchanged.
        self.mlp = self._build_mlp(in_channels, hidden_channels,
                                   out_channels, mlp_layers)
        # NOTE(review): normalize=False presumably means the caller supplies an
        # already-normalised adjacency ``adj_t`` — confirm at the call site.
        self.appnp = APPNP(num_iteration, alpha, dropout=dropout,
                           normalize=False)

    @staticmethod
    def _build_mlp(in_channels, hidden_channels, out_channels, num_layers):
        """Return a ``ModuleList`` holding exactly ``num_layers`` Linear layers.

        Layer widths follow ``in -> hidden -> ... -> hidden -> out``. With a
        single layer, this is ``in -> out`` directly.

        Raises:
            ValueError: if ``num_layers < 1``.
        """
        if num_layers < 1:
            raise ValueError(f"mlp_layers must be >= 1, got {num_layers}")
        dims = [in_channels] + [hidden_channels] * (num_layers - 1) + [out_channels]
        return torch.nn.ModuleList(
            torch.nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )

    def reset_parameters(self):
        """Re-initialise the APPNP module and every MLP layer."""
        self.appnp.reset_parameters()
        for linear in self.mlp:
            linear.reset_parameters()

    def forward(self, x, adj_t):
        """Apply the MLP to ``x``, then propagate the result with APPNP over ``adj_t``."""
        for linear in self.mlp[:-1]:
            x = F.relu(linear(x))
        x = self.mlp[-1](x)
        return self.appnp(x, adj_t)
class Net(torch.nn.Module):
    """Two stacked APPNP layers with input dropout, returning log-probabilities.

    Hyperparameters (``hidden``, ``K``, ``alpha``, ``dropout``) are read from a
    global ``args`` object defined outside this class (presumably parsed
    command-line arguments). ``dropout`` is re-read on every forward pass.

    NOTE(review): ``APPNP`` is called here as ``APPNP(in, out, K, alpha)``.
    ``APPNPNet`` calls it as ``APPNP(K, alpha, dropout=..., normalize=...)``.
    If both classes live in one module, these calls cannot both match a single
    ``APPNP`` definition — confirm which class this one refers to.
    """

    def __init__(self, dataset):
        super().__init__()
        self.conv1 = APPNP(dataset.num_features, args.hidden, args.K, args.alpha)
        self.conv2 = APPNP(args.hidden, dataset.num_classes, args.K, args.alpha)

    def reset_parameters(self):
        """Re-initialise both propagation layers, first then second."""
        for conv in (self.conv1, self.conv2):
            conv.reset_parameters()

    def forward(self, data):
        """Run dropout -> conv1 -> ReLU -> dropout -> conv2 -> log_softmax.

        ``data`` must expose ``x`` and ``edge_index`` attributes.
        """
        h = data.x
        edge_index = data.edge_index
        layers = (self.conv1, self.conv2)
        for depth, conv in enumerate(layers, start=1):
            h = F.dropout(h, p=args.dropout, training=self.training)
            h = conv(h, edge_index)
            if depth < len(layers):
                # Nonlinearity only between layers, not after the final one.
                h = F.relu(h)
        return F.log_softmax(h, dim=1)