Example #1
 def __init__(self, in_feats, n_hidden, n_classes):
     super().__init__()
     self.layers = nn.ModuleList()
     self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
     self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
     self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
     self.dropout = nn.Dropout(0.5)
Example #2
 def __init__(self, in_feats, hid_feats, out_feats):
     super().__init__()
     self.conv1 = dglnn.SAGEConv(in_feats=in_feats,
                                 out_feats=hid_feats,
                                 aggregator_type="mean")
     self.conv2 = dglnn.SAGEConv(in_feats=hid_feats,
                                 out_feats=out_feats,
                                 aggregator_type="mean")
Example #3
 def __init__(self, in_size, hid_size, out_size):
     super().__init__()
     self.layers = nn.ModuleList()
     # three-layer GraphSAGE-mean
     self.layers.append(dglnn.SAGEConv(in_size, hid_size, 'mean'))
     self.layers.append(dglnn.SAGEConv(hid_size, hid_size, 'mean'))
     self.layers.append(dglnn.SAGEConv(hid_size, out_size, 'mean'))
     self.dropout = nn.Dropout(0.5)
     self.hid_size = hid_size
     self.out_size = out_size
Example #4
 def __init__(self, in_feats, hid_feats, out_feats):
     super().__init__()
     # three-layer GraphSAGE with 'pool' aggregators and a fixed 1000-dim middle layer
     self.conv1 = dglnn.SAGEConv(in_feats=in_feats,
                                 out_feats=hid_feats,
                                 aggregator_type='pool')
     self.conv2 = dglnn.SAGEConv(in_feats=hid_feats,
                                 out_feats=1000,
                                 aggregator_type='pool')
     self.conv3 = dglnn.SAGEConv(in_feats=1000,
                                 out_feats=out_feats,
                                 aggregator_type='pool')
Example #5
 def __init__(self, in_feats, n_hidden):
     super().__init__()
     self.n_hidden = n_hidden
     self.layers = nn.ModuleList()
     self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
     self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
     self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
     # scoring head: maps an n_hidden-dim representation to a single score
     self.predictor = nn.Sequential(
         nn.Linear(n_hidden, n_hidden), nn.ReLU(),
         nn.Linear(n_hidden, n_hidden), nn.ReLU(),
         nn.Linear(n_hidden, 1))
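The snippet above only builds the encoder layers and the scoring MLP; below is a hedged sketch of how such a predictor is commonly applied to a candidate link, assuming node embeddings h have already been produced by the encoder (the element-wise-product pairing and the helper name score_edge are assumptions, not taken from the example):

import torch
import torch.nn as nn

def score_edge(predictor: nn.Module, h: torch.Tensor, u: int, v: int) -> torch.Tensor:
    # combine the two endpoint embeddings and let the MLP map them to a single score
    return predictor(h[u] * h[v])

# example with 4 nodes, hidden size 16, and a predictor shaped like the one above
predictor = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
h = torch.randn(4, 16)
print(score_edge(predictor, h, 0, 3))  # one-element score tensor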
Example #6
 def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
              dropout):
     super().__init__()
     self.n_layers = n_layers
     self.n_hidden = n_hidden
     self.n_classes = n_classes
     self.layers = nn.ModuleList()
     self.layers.append(dglnn.SAGEConv(in_feats, n_hidden, 'mean'))
     for i in range(1, n_layers - 1):
         self.layers.append(dglnn.SAGEConv(n_hidden, n_hidden, 'mean'))
     self.layers.append(dglnn.SAGEConv(n_hidden, n_classes, 'mean'))
     self.dropout = nn.Dropout(dropout)
     self.activation = activation
Example #7
 def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation,
              dropout):
     super().__init__()
     self.n_layers = n_layers
     self.n_hidden = n_hidden
     self.n_classes = n_classes
     self.layers = nn.ModuleList()
     self.bns = nn.ModuleList()
     self.res_linears = nn.ModuleList()
     # input layer: SAGEConv (bias disabled) + BatchNorm, with a linear residual projection
     self.layers.append(
         dglnn.SAGEConv(in_feats,
                        n_hidden,
                        'mean',
                        bias=False,
                        feat_drop=dropout))
     self.bns.append(torch.nn.BatchNorm1d(n_hidden))
     self.res_linears.append(torch.nn.Linear(in_feats, n_hidden))
     # hidden layers keep the same width, so the residual branches are identities
     for i in range(1, n_layers - 1):
         self.layers.append(
             dglnn.SAGEConv(n_hidden,
                            n_hidden,
                            'mean',
                            bias=False,
                            feat_drop=dropout))
         self.bns.append(torch.nn.BatchNorm1d(n_hidden))
         self.res_linears.append(torch.nn.Identity())
     # final SAGE layer, also followed by BatchNorm and an identity residual
     self.layers.append(
         dglnn.SAGEConv(n_hidden,
                        n_hidden,
                        'mean',
                        bias=False,
                        feat_drop=dropout))
     self.bns.append(torch.nn.BatchNorm1d(n_hidden))
     self.res_linears.append(torch.nn.Identity())
     # MLP head sized for in_feats plus every layer's n_hidden output (suggesting concatenation)
     self.mlp = MLP(in_feats + n_hidden * n_layers,
                    2 * n_classes,
                    n_classes,
                    num_layers=2,
                    bn=True,
                    end_up_with_fc=True,
                    act='LeakyReLU')
     self.dropout = nn.Dropout(dropout)
     self.activation = activation
     self.profile = locals()
Example #8
def test_sage_conv_bi(idtype, g, aggre_type):
    g = g.astype(idtype)
    dst_dim = 5 if aggre_type != 'gcn' else 10
    sage = nn.SAGEConv((10, dst_dim), 2, aggre_type)
    feat = (F.randn((g.number_of_src_nodes(), 10)),
            F.randn((g.number_of_dst_nodes(), dst_dim)))

    init_params = sage.init(jax.random.PRNGKey(2666), g, feat)
    h = sage.apply(init_params, g, feat)
    assert h.shape[-1] == 2
    assert h.shape[0] == g.number_of_dst_nodes()
Example #9
    def __init__(self, src_id, dst_id, in_c, hid_c, n_layers, device):
        super(SAGEModel, self).__init__()
        self.graph = dgl.graph((src_id, dst_id), device=device)

        self.gcn = nn.ModuleList([
            gnn.SAGEConv(in_c if i == 0 else hid_c, hid_c, "pool")
            for i in range(n_layers)
        ])

        self.residual = nn.ModuleList([
            nn.Identity() if i != 0 else nn.Linear(in_c, hid_c)
            for i in range(n_layers)
        ])
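The forward pass for this model is not part of the snippet; one plausible sketch that is consistent with the gcn and residual module lists above (the per-layer residual addition and the ReLU are assumptions, not the author's code):

import torch.nn.functional as F

def forward(self, x):
    # self.graph was materialized in __init__; run each SAGE layer with a residual skip
    h = x
    for conv, res in zip(self.gcn, self.residual):
        h = F.relu(conv(self.graph, h)) + res(h)
    return h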
Example #10
def test_sage_conv2(idtype):
    # TODO: add test for blocks
    # Test the case for graphs without edges
    g = dgl.heterograph({('_U', '_E', '_V'): ([], [])}, {'_U': 5, '_V': 3})
    g = g.astype(idtype)
    sage = nn.SAGEConv((3, 3), 2, 'gcn')
    feat = (F.randn((5, 3)), F.randn((3, 3)))

    init_params = sage.init(
        jax.random.PRNGKey(2666), g,
        (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    h = sage.apply(init_params, g,
                   (F.copy_to(feat[0], F.ctx()), F.copy_to(feat[1], F.ctx())))
    assert h.shape[-1] == 2
    assert h.shape[0] == 3

    for aggre_type in ['mean', 'pool']:

        sage = nn.SAGEConv((3, 1), 2, aggre_type)
        feat = (F.randn((5, 3)), F.randn((3, 1)))
        init_params = sage.init(jax.random.PRNGKey(2666), g, feat)
        h = sage.apply(init_params, g, feat)
        assert h.shape[-1] == 2
        assert h.shape[0] == 3
Example #11
def test_sage_conv(idtype, g, aggre_type):
    sage = nn.SAGEConv(5, 10, aggre_type)
    feat = F.randn((g.number_of_nodes(), 5))
    init_params = sage.init(jax.random.PRNGKey(2666), g, feat)
    h = sage.apply(init_params, g, feat)
    assert h.shape[-1] == 10