コード例 #1
0
ファイル: agnn.py プロジェクト: zwytop/GraphEmbedding
 def __init__(self):
     """Register the submodules of the AGNN example network.

     Creates, in order: an input ``Linear`` layer, two ``AGNNConv``
     propagation layers and an output ``Linear`` layer. Input and output
     widths are read from the module-level ``dataset`` object
     (``num_features`` / ``num_classes``); the hidden width is 16.
     """
     super(Net, self).__init__()
     hidden_width = 16
     # Creation order is kept as-is: it fixes parameter registration order
     # and the order in which Linear layers draw their random init.
     self.lin1 = torch.nn.Linear(dataset.num_features, hidden_width)
     # requires_grad toggles whether each AGNNConv's own parameter is
     # trainable: frozen in the first layer, learned in the second.
     self.prop1 = AGNNConv(requires_grad=False)
     self.prop2 = AGNNConv(requires_grad=True)
     self.lin2 = torch.nn.Linear(hidden_width, dataset.num_classes)
コード例 #2
0
 def __init__(self, in_dim, out_dim):
     """Register a single ``AGNNConv`` layer with a trainable parameter.

     ``in_dim`` and ``out_dim`` are accepted for interface compatibility
     but are not used here; ``AGNNConv`` is constructed without sizes.
     """
     super(Breadth, self).__init__()
     learn_parameter = True
     # NOTE(review): the attribute is named ``gatconv`` but holds an
     # AGNNConv; the name is kept because other code (e.g. forward)
     # presumably refers to it.
     self.gatconv = AGNNConv(requires_grad=learn_parameter)
コード例 #3
0
 def __init__(self, in_dim, out_dim):
     """Register a Linear -> AGNNConv x2 -> Linear set of submodules.

     Args:
         in_dim: Input feature width of the first ``Linear`` layer.
         out_dim: Output width of the last ``Linear`` layer.

     The hidden width between the two ``Linear`` layers is 16.
     """
     super(Breadth, self).__init__()
     hidden_width = 16
     # Keep creation order: it determines parameter registration order and
     # the order in which the Linear layers consume the RNG at init.
     self.lin1 = torch.nn.Linear(in_dim, hidden_width)
     # First propagation layer's AGNNConv parameter is frozen; the second
     # layer's is trainable.
     self.prop1 = AGNNConv(requires_grad=False)
     self.prop2 = AGNNConv(requires_grad=True)
     self.lin2 = torch.nn.Linear(hidden_width, out_dim)
コード例 #4
0
    def __init__(self,
                 in_dim,
                 hidden_dim,
                 out_dim,
                 dropout=0.5,
                 name='gat',
                 heads=8,
                 residual=True):
        """Build a two-stage PyTorch Geometric model selected by ``name``.

        Layers created per architecture:
            * ``'gat'``, ``'gcn'``, ``'cheb'``, ``'spline'``, ``'gin'``:
              ``conv1`` and ``conv2``.
            * ``'unet'``: ``conv1`` only (a ``GraphUNet`` mapping
              ``in_dim -> out_dim`` directly).
            * ``'agnn'``: ``lin1``, ``conv1``, ``conv2`` and ``lin2``.

        Args:
            in_dim: Width of the input node features.
            hidden_dim: Width of the hidden representation (per head for
                ``'gat'``).
            out_dim: Width of the output.
            dropout: Stored on ``self.dropout``; also passed as the
                attention dropout of both ``GATConv`` layers for ``'gat'``.
            name: Architecture to build; one of ``'gat'``, ``'gcn'``,
                ``'cheb'``, ``'spline'``, ``'gin'``, ``'unet'``, ``'agnn'``.
            heads: Number of attention heads in the first GAT layer; ignored
                by the other architectures.
            residual: If true, ``self.residual`` is ``Identity()`` when
                ``in_dim == out_dim`` and ``Linear(in_dim, out_dim)``
                otherwise; if false it is ``None``.

        Raises:
            NotImplementedError: If ``name`` is not a supported architecture.
        """
        super(GNNModelPYG, self).__init__()
        self.dropout = dropout
        self.name = name
        self.residual = None
        if residual:
            if in_dim == out_dim:
                self.residual = Identity()
            else:
                # Project the input so it can be combined with an output
                # of a different width.
                self.residual = Linear(in_dim, out_dim)

        if name == 'gat':
            self.conv1 = GATConv(in_dim,
                                 hidden_dim,
                                 heads=heads,
                                 dropout=dropout)
            # conv1 concatenates its heads, hence hidden_dim * heads inputs;
            # the output layer uses a single, non-concatenated head.
            self.conv2 = GATConv(hidden_dim * heads,
                                 out_dim,
                                 heads=1,
                                 concat=False,
                                 dropout=dropout)
        elif name == 'gcn':
            # NOTE(review): cached=True reuses the normalized graph across
            # calls, which assumes the same graph every time; self-loops are
            # deliberately not added -- confirm both for the caller's data.
            self.conv1 = GCNConv(in_dim,
                                 hidden_dim,
                                 cached=True,
                                 normalize=True,
                                 add_self_loops=False)
            self.conv2 = GCNConv(hidden_dim,
                                 out_dim,
                                 cached=True,
                                 normalize=True,
                                 add_self_loops=False)
        elif name == 'cheb':
            self.conv1 = ChebConv(in_dim, hidden_dim, K=2)
            self.conv2 = ChebConv(hidden_dim, out_dim, K=2)
        elif name == 'spline':
            # presumably expects 1-D edge pseudo-coordinates (dim=1) -- confirm
            self.conv1 = SplineConv(in_dim, hidden_dim, dim=1, kernel_size=2)
            self.conv2 = SplineConv(hidden_dim, out_dim, dim=1, kernel_size=2)
        elif name == 'gin':
            self.conv1 = GINConv(
                Sequential(Linear(in_dim, hidden_dim), ReLU(),
                           Linear(hidden_dim, hidden_dim)))
            self.conv2 = GINConv(
                Sequential(Linear(hidden_dim, hidden_dim), ReLU(),
                           Linear(hidden_dim, out_dim)))
        elif name == 'unet':
            self.conv1 = GraphUNet(in_dim, hidden_dim, out_dim, depth=3)
        elif name == 'agnn':
            self.lin1 = Linear(in_dim, hidden_dim)
            self.conv1 = AGNNConv(requires_grad=False)
            self.conv2 = AGNNConv(requires_grad=True)
            self.lin2 = Linear(hidden_dim, out_dim)
        else:
            # Must be NotImplementedError: the NotImplemented constant is not
            # callable, so calling it would raise an unrelated TypeError.
            raise NotImplementedError(
                "Unknown model name {!r}. Choose from gat, gcn, cheb, "
                "spline, gin, unet, agnn.".format(name))