class TAG_Net(torch.nn.Module):
    """Two-layer TAGConv network that outputs log-softmax scores.

    Args:
        features_num: Size of each input feature vector (in_channels of conv1).
        num_class: Number of output classes (out_channels of conv2).
        hidden: Width of the intermediate representation.
        dropout: Dropout probability applied between the two convolutions.
    """

    def __init__(self, features_num, num_class, hidden, dropout):
        super(TAG_Net, self).__init__()
        self.dropout = dropout
        self.conv1 = TAGConv(features_num, hidden)
        self.conv2 = TAGConv(hidden, num_class)

    def reset_parameters(self):
        """Re-initialise the learnable weights of both convolutions."""
        for layer in (self.conv1, self.conv2):
            layer.reset_parameters()

    def forward(self, data):
        """Apply conv1 -> ReLU -> dropout -> conv2 -> log_softmax(dim=1).

        ``data`` must expose ``x`` (input features) and ``edge_index``.
        """
        node_feats, edge_index = data.x, data.edge_index
        hidden_repr = F.relu(self.conv1(node_feats, edge_index))
        # Dropout is only active while the module is in training mode.
        hidden_repr = F.dropout(hidden_repr, p=self.dropout, training=self.training)
        logits = self.conv2(hidden_repr, edge_index)
        return F.log_softmax(logits, dim=1)
class TAGNet(nn.Module):
    """Stack of TAGConv layers, optionally edge-weighted, ending in log-softmax.

    Layer usage in ``forward``: ``conv0`` (``num_feature -> hidden``) is the
    first layer. Every further hidden layer reuses the *same* ``conv1`` module,
    so its weights are shared when ``num_layers > 3``. ``conv2``
    (``hidden -> num_class``) is the output layer. ``conv1`` is always
    created and reset, even when ``num_layers == 2`` leaves it unused, which
    keeps the parameter names stable.

    Args:
        num_feature: Size of each input feature vector.
        num_class: Number of output classes.
        num_layers: Total number of convolutions applied in ``forward``.
        k: Hop count ``K`` passed to every TAGConv.
        hidden: Width of the hidden layers.
        drop: Dropout probability applied after every hidden layer.
        use_edge_weight: If True, ``data.edge_attr`` is passed to every
            convolution as the edge weight. If False, ``edge_attr`` is never
            read (it may be missing/None).

    Raises:
        ValueError: If ``num_layers < 2`` and ``num_feature != hidden``. No
            hidden layer runs in that case, so ``conv2`` (built for
            ``hidden``-wide input) would get ``num_feature``-wide input.
    """

    def __init__(self, num_feature, num_class, num_layers=2, k=3, hidden=64, drop=0.5, use_edge_weight=True):
        # Fail fast with a clear message instead of a shape error in forward().
        if num_layers < 2 and num_feature != hidden:
            raise ValueError(
                f"num_layers={num_layers} applies only conv2, which expects "
                f"hidden={hidden} input features but would receive "
                f"num_feature={num_feature}; use num_layers >= 2 or "
                f"num_feature == hidden"
            )
        super(TAGNet, self).__init__()
        self.conv0 = TAGConv(num_feature, hidden, K=k)
        self.conv1 = TAGConv(hidden, hidden, K=k)
        self.conv2 = TAGConv(hidden, num_class, K=k)
        self.n_layer = num_layers
        self.use_edge_weight = use_edge_weight
        self.drop = drop

    def reset_parameters(self):
        """Re-initialise the learnable weights of all three convolutions."""
        self.conv0.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    @staticmethod
    def _conv(conv, x, edge_index, edge_weight):
        """Call ``conv``, passing ``edge_weight`` only when one is supplied."""
        if edge_weight is None:
            return conv(x, edge_index)
        return conv(x, edge_index, edge_weight)

    def forward(self, data):
        """Run the stack on ``data`` (needs ``x``, ``edge_index``; ``edge_attr`` if weighted).

        Returns ``log_softmax`` of the final convolution output over dim 1.
        """
        x, edge_index = data.x, data.edge_index

        edge_weight = None
        if self.use_edge_weight:
            edge_weight = data.edge_attr
            # Drop a trailing singleton column (e.g. (E, 1) -> (E,)). An
            # already 1-D edge_attr is used as-is; squeeze(1) would raise on it.
            if edge_weight.dim() > 1:
                edge_weight = edge_weight.squeeze(1)

        for i in range(self.n_layer - 1):
            # First hidden layer maps num_feature -> hidden; later ones share conv1.
            conv = self.conv0 if i == 0 else self.conv1
            x = self._conv(conv, x, edge_index, edge_weight)
            x = F.relu(x)
            x = F.dropout(x, p=self.drop, training=self.training)

        x = self._conv(self.conv2, x, edge_index, edge_weight)
        return F.log_softmax(x, dim=1)