Beispiel #1
0
class GraphSAGE(nn.Module):
    """Two-layer GraphSAGE node encoder built from ``SAGEConv`` layers.

    Layout: ``SAGEConv -> ReLU -> Dropout -> SAGEConv``. The second layer's
    output is returned without an activation so callers can attach their own
    head / loss.

    Args:
        node_in_dim: Size of each input node feature vector.
        node_out_dim: Hidden and output feature size of both layers.
        heads: Accepted only for signature compatibility with attention-based
            encoders; GraphSAGE has no attention heads, so it is ignored.
        dropout: Dropout probability applied to the hidden representation
            between the two layers (active only in ``train()`` mode).
    """

    def __init__(self, node_in_dim, node_out_dim=64,
                 heads=1,
                 dropout=0.1
                 ):
        super().__init__()
        self.conv1 = SAGEConv(node_in_dim, node_out_dim)
        self.conv2 = SAGEConv(node_out_dim, node_out_dim)
        # nn.Dropout has no parameters or buffers, so registering it does not
        # change the state_dict keys of previously saved checkpoints.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, edge_attr=None):
        """Encode node features with two rounds of neighbourhood aggregation.

        Args:
            x: Node feature matrix; presumably shaped
                (num_nodes, node_in_dim) -- TODO confirm against callers.
            edge_index: Graph connectivity passed unchanged to both layers.
            edge_attr: Ignored; ``SAGEConv`` is called without edge features.
                Kept so callers that pass edge attributes keep working.

        Returns:
            The output of the second ``SAGEConv`` layer (no final activation).
        """
        # Call the submodules rather than ``.forward`` so nn.Module hooks run.
        x = self.conv1(x, edge_index)
        # Without a non-linearity the two stacked layers collapse into a
        # single linear model; dropout regularises the hidden representation.
        x = self.dropout(x.relu())
        return self.conv2(x, edge_index)