Example #1

import torch
import torch.nn.functional as F
from torch_geometric.nn import TAGConv

class TAG_Net(torch.nn.Module):
    def __init__(self, features_num, num_class, hidden, dropout):
        super(TAG_Net, self).__init__()
        self.dropout = dropout
        # TAGConv defaults to K=3 propagation hops per layer.
        self.conv1 = TAGConv(features_num, hidden)
        self.conv2 = TAGConv(hidden, num_class)

    def reset_parameters(self):
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # Two TAG convolutions with ReLU and dropout between them.
        x = F.relu(self.conv1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        # Log-probabilities per node; pair with F.nll_loss during training.
        return F.log_softmax(x, dim=1)
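
The listing ships no driver code, so here is a minimal smoke test; the toy graph, feature sizes, and labels are hypothetical, and torch_geometric is assumed to be installed.

from torch_geometric.data import Data

# Hypothetical toy graph: 4 nodes with 16 features each, two undirected edges.
data = Data(
    x=torch.randn(4, 16),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 0, 3, 2]]),
    y=torch.tensor([0, 1, 0, 1]),
)

model = TAG_Net(features_num=16, num_class=2, hidden=32, dropout=0.5)
model.train()
out = model(data)               # log-probabilities, shape [4, 2]
loss = F.nll_loss(out, data.y)  # nll_loss pairs with the log_softmax output
loss.backward()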
Example #2

import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import TAGConv

class TAGNet(nn.Module):
    def __init__(self, num_feature, num_class, num_layers=2, k=3, hidden=64, drop=0.5, use_edge_weight=True):
        super(TAGNet, self).__init__()
        # Three convolutions: input -> hidden, hidden -> hidden, hidden -> classes.
        # conv1 is reused for every intermediate layer, so all middle layers
        # share one set of weights.
        self.conv0 = TAGConv(num_feature, hidden, K=k)
        self.conv1 = TAGConv(hidden, hidden, K=k)
        self.conv2 = TAGConv(hidden, num_class, K=k)
        self.n_layer = num_layers
        self.use_edge_weight = use_edge_weight
        self.drop = drop

    def reset_parameters(self):
        self.conv0.reset_parameters()
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        # edge_attr is stored as [num_edges, 1]; squeeze it to the 1-D vector
        # TAGConv expects. Only touch it when edge weights are actually used,
        # so graphs without edge_attr still work with use_edge_weight=False.
        edge_weight = data.edge_attr.squeeze(1) if self.use_edge_weight else None

        for i in range(self.n_layer - 1):
            # The first layer lifts input features to the hidden size;
            # subsequent intermediate layers reuse conv1.
            conv = self.conv0 if i == 0 else self.conv1
            x = conv(x, edge_index, edge_weight)  # edge_weight may be None
            x = F.relu(x)
            x = F.dropout(x, p=self.drop, training=self.training)

        x = self.conv2(x, edge_index, edge_weight)

        return F.log_softmax(x, dim=1)
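
As above, a hedged usage sketch; the graph, edge weights, and hyper-parameters are made up. Note that edge_attr must have shape [num_edges, 1] to survive the squeeze(1) in forward().

import torch
from torch_geometric.data import Data

# Hypothetical toy graph: 5 nodes, 8 features, one positive weight per edge.
data = Data(
    x=torch.randn(5, 8),
    edge_index=torch.tensor([[0, 1, 2, 3], [1, 2, 3, 4]]),
    edge_attr=torch.rand(4, 1),
    y=torch.tensor([0, 1, 2, 1, 0]),
)

model = TAGNet(num_feature=8, num_class=3, num_layers=3, k=2, hidden=16)
model.eval()
with torch.no_grad():
    log_probs = model(data)        # shape [5, 3]
pred = log_probs.argmax(dim=1)     # predicted class per node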