예제 #1
0
    def __init__(self, in_dim, hidden_dim, n_classes, hidden_layers, readout,
                 activation_func, dropout, grid, device):
        """Build an EdgeConv-based graph classifier.

        Args:
            in_dim: Input feature size consumed by the first EdgeConv layer.
            hidden_dim: Output size of every EdgeConv layer (and of each
                BatchNorm1d that follows it).
            n_classes: Number of outputs of the final classifier.
            hidden_layers: Number of extra ``hidden_dim -> hidden_dim``
                EdgeConv layers stacked after the input layer.
            readout: Graph readout; one of ``'max'``, ``'mean'``, ``'sum'``,
                ``'gap'`` (global attention pooling) or ``'spp'``.
            activation_func: Not used by this constructor.
                NOTE(review): confirm whether ``forward`` is meant to use it.
            dropout: Probability for ``self.dropout``.
            grid: Grid size handed to ``SppPooling``; the ``'spp'`` classifier
                expects ``hidden_dim * grid * grid`` input features.
            device: Stored as ``self.device``.

        Raises:
            ValueError: If ``readout`` is not one of the supported values.
                Previously any unknown value silently fell through to
                SppPooling while the classifier head was sized for a
                non-spp readout, producing an inconsistent model.
        """
        valid_readouts = ('max', 'mean', 'sum', 'gap', 'spp')
        if readout not in valid_readouts:
            raise ValueError(
                "readout must be one of {}; got {!r}".format(
                    valid_readouts, readout))

        super(Classifier, self).__init__()
        self.device = device
        self.readout = readout
        self.layers = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.grid = grid

        # input layer
        self.layers.append(conv.EdgeConv(in_dim, hidden_dim))
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # hidden layers
        for _ in range(hidden_layers):
            self.layers.append(conv.EdgeConv(hidden_dim, hidden_dim))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # readout (graph-level pooling)
        if self.readout == 'max':
            self.readout_fcn = conv.MaxPooling()
        elif self.readout == 'mean':
            self.readout_fcn = conv.AvgPooling()
        elif self.readout == 'sum':
            self.readout_fcn = conv.SumPooling()
        elif self.readout == 'gap':
            # gate network -> 1 score per node; feature network doubles width,
            # which is why the classifier below takes hidden_dim * 2 inputs.
            self.readout_fcn = conv.GlobalAttentionPooling(
                nn.Linear(hidden_dim, 1), nn.Linear(hidden_dim,
                                                    hidden_dim * 2))
        else:  # 'spp' (validated above)
            self.readout_fcn = SppPooling(hidden_dim, self.grid)

        # classifier head, sized to match the readout output
        if self.readout == 'spp':
            self.classify = nn.Sequential(
                # NOTE(review): uses nn.Dropout's default p, not `dropout`;
                # kept as-is to preserve behavior — confirm intent.
                nn.Dropout(),
                nn.Linear(hidden_dim * self.grid * self.grid, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, n_classes),
            )
        else:
            var = hidden_dim
            if self.readout == 'gap':
                var *= 2
            self.classify = nn.Linear(var, n_classes)
예제 #2
0
파일: test_nn.py 프로젝트: yangce0224/dgl
def test_edge_conv(g, idtype):
    """EdgeConv(5, 2) on a homogeneous graph yields one 2-dim row per node."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.EdgeConv(5, 2).to(device)
    print(layer)

    num_nodes = graph.number_of_nodes()
    features = F.randn((num_nodes, 5))
    output = layer(graph, features)
    assert output.shape == (num_nodes, 2)
예제 #3
0
파일: test_nn.py 프로젝트: lygztq/dgl
def test_edge_conv_bi(g, idtype, out_dim):
    """EdgeConv on a bipartite graph, fed separate src/dst feature tensors."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.EdgeConv(5, out_dim).to(device)
    print(layer)

    src_feat = F.randn((graph.number_of_src_nodes(), 5))
    dst_feat = F.randn((graph.number_of_dst_nodes(), 5))
    result = layer(graph, (src_feat, dst_feat))
    expected_shape = (graph.number_of_dst_nodes(), out_dim)
    assert result.shape == expected_shape
예제 #4
0
파일: test_nn.py 프로젝트: lygztq/dgl
def test_edge_conv(g, idtype, out_dim):
    """EdgeConv(5, out_dim) can be saved with th.save and gives dst-sized output."""
    device = F.ctx()
    graph = g.astype(idtype).to(device)
    layer = nn.EdgeConv(5, out_dim).to(device)
    print(layer)

    # the module must be picklable (serialized into the shared buffer)
    th.save(layer, tmp_buffer)

    features = F.randn((graph.number_of_src_nodes(), 5))
    output = layer(graph, features)
    expected_shape = (graph.number_of_dst_nodes(), out_dim)
    assert output.shape == expected_shape
예제 #5
0
파일: test_nn.py 프로젝트: weibao918/dgl
def test_edge_conv(g):
    """Check EdgeConv(5, 2) output shape on homogeneous, block and bipartite graphs.

    For a bipartite (non-homogeneous, non-block) graph the layer is given a
    ``(src_feat, dst_feat)`` pair; otherwise a single source-feature tensor.
    The output must have one 2-dim row per destination node.
    """
    ctx = F.ctx()

    edge_conv = nn.EdgeConv(5, 2).to(ctx)
    print(edge_conv)

    # test #1: basic
    h0 = F.randn((g.number_of_src_nodes(), 5))
    if not g.is_homograph() and not g.is_block:
        # bipartite: size the destination features from the graph itself
        # rather than assuming exactly 10 destination nodes (the old
        # h0[:10] also silently broke when there were < 10 source nodes).
        x0 = F.randn((g.number_of_dst_nodes(), 5))
        h1 = edge_conv(g, (h0, x0))
    else:
        h1 = edge_conv(g, h0)
    assert h1.shape == (g.number_of_dst_nodes(), 2)