class CLS(torch.nn.Module):
    """Classification head: one GCN layer followed by a log-softmax over classes.

    Args:
        d_in: number of input node-feature channels.
        d_out: number of output classes.
    """

    def __init__(self, d_in, d_out):
        super(CLS, self).__init__()
        self.conv = GCNConv(d_in, d_out)

    def reset_parameters(self):
        """Re-initialize the underlying GCN layer's weights."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=None):
        # `mask` is accepted for interface compatibility with sibling blocks
        # but is not used in this head.
        logits = self.conv(x, edge_index)
        return F.log_softmax(logits, dim=1)
class CRD(torch.nn.Module):
    """Feature-transform block: GCN layer -> ReLU -> dropout.

    Args:
        d_in: number of input node-feature channels.
        d_out: number of output channels.
        p: dropout probability applied after the activation.
    """

    def __init__(self, d_in, d_out, p):
        super(CRD, self).__init__()
        self.conv = GCNConv(d_in, d_out)
        self.p = p

    def reset_parameters(self):
        """Re-initialize the underlying GCN layer's weights."""
        self.conv.reset_parameters()

    def forward(self, x, edge_index, mask=None):
        # `mask` is accepted for interface compatibility but not used here.
        hidden = F.relu(self.conv(x, edge_index))
        # Dropout is active only while self.training is True.
        return F.dropout(hidden, p=self.p, training=self.training)
class GCNEncoder(nn.Module):
    """Variational-style graph encoder producing (mu, logvar) embeddings.

    A shared GCN layer feeds two parallel GCN heads that output the mean
    and log-variance of the latent node representation (as used by VGAE).
    All three layers use cached=True, so they cache the normalized adjacency
    after the first call — suitable for transductive, fixed-graph settings.

    Args:
        in_channels: number of input node-feature channels.
        hidden_channels: width of the shared intermediate layer.
        out_channels: dimensionality of the latent space.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(GCNEncoder, self).__init__()
        self.gcn_shared = GCNConv(in_channels, hidden_channels, cached=True)
        self.gcn_mu = GCNConv(hidden_channels, out_channels, cached=True)
        self.gcn_logvar = GCNConv(hidden_channels, out_channels, cached=True)

    def reset_parameters(self):
        """Re-initialize the shared layer and both output heads."""
        for layer in (self.gcn_shared, self.gcn_mu, self.gcn_logvar):
            layer.reset_parameters()

    def forward(self, x, edge_index):
        shared = F.relu(self.gcn_shared(x, edge_index))
        mu = self.gcn_mu(shared, edge_index)
        logvar = self.gcn_logvar(shared, edge_index)
        return mu, logvar