Example #1
0
    def __init__(self, in_dim, hidden_dim, n_classes, hidden_layers, n_steps, readout,
                 activation_func, dropout, grid, device, sort_k=100):
        """Build gated graph conv layers, a graph readout and a classifier head.

        Args:
            in_dim: Input feature size of the first GatedGraphConv layer.
            hidden_dim: Output feature size of every GatedGraphConv layer; one
                BatchNorm1d of this size is created per conv layer.
            n_classes: Output size of the classifier head.
            hidden_layers: Number of extra hidden_dim -> hidden_dim conv layers
                stacked after the input layer.
            n_steps: Number of propagation steps passed to each GatedGraphConv.
            readout: Graph readout name: 'max', 'mean', 'sum', 'gap', 'sort',
                'set' or 'spp'.
            activation_func: Accepted for interface compatibility; this
                constructor neither stores nor uses it.
            dropout: Drop probability of ``self.dropout``. The 'spp' and 'sort'
                classifier heads use their own ``nn.Dropout()`` with PyTorch's
                default probability instead.
            grid: Grid size passed to SppPooling; the 'spp' head expects
                ``hidden_dim * grid * grid`` input features.
            device: Stored as ``self.device``.
            sort_k: Number of nodes kept by SortPooling when
                ``readout == 'sort'``; the 'sort' head is sized
                ``hidden_dim * sort_k`` accordingly.

        Raises:
            ValueError: If ``readout`` is not one of the supported names.
        """
        supported_readouts = ('max', 'mean', 'sum', 'gap', 'sort', 'set', 'spp')
        if readout not in supported_readouts:
            # Fail fast: without this check an unknown name would get an
            # SppPooling readout paired with a head sized for a plain readout.
            raise ValueError(
                f"readout must be one of {supported_readouts}, got {readout!r}")

        super(Classifier, self).__init__()
        self.device      = device
        self.readout     = readout
        self.layers      = nn.ModuleList()
        self.batch_norms = nn.ModuleList()
        self.grid        = grid
        self.sort_k      = sort_k

        # input layer: in_dim -> hidden_dim; last argument 1 is presumably the
        # number of edge types (DGL's n_etypes) — confirm against conv's API
        self.layers.append(conv.GatedGraphConv(in_dim, hidden_dim, n_steps, 1))
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # hidden layers: hidden_dim -> hidden_dim
        for _ in range(hidden_layers):
            self.layers.append(conv.GatedGraphConv(hidden_dim, hidden_dim, n_steps, 1))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

        # dropout layer
        self.dropout = nn.Dropout(p=dropout)

        # graph readout
        if self.readout == 'max':
            self.readout_fcn = conv.MaxPooling()
        elif self.readout == 'mean':
            self.readout_fcn = conv.AvgPooling()
        elif self.readout == 'sum':
            self.readout_fcn = conv.SumPooling()
        elif self.readout == 'gap':
            # feature network maps to hidden_dim * 2, hence the doubled head below
            self.readout_fcn = conv.GlobalAttentionPooling(
                nn.Linear(hidden_dim, 1), nn.Linear(hidden_dim, hidden_dim * 2))
        elif self.readout == 'sort':
            self.readout_fcn = conv.SortPooling(sort_k)
        elif self.readout == 'set':
            self.readout_fcn = conv.Set2Set(hidden_dim, 2, 2)
        else:  # 'spp' (validated above)
            self.readout_fcn = SppPooling(hidden_dim, self.grid)

        # classifier head, sized to match the chosen readout's output
        if self.readout == 'spp':
            self.classify = nn.Sequential(
                nn.Dropout(),
                nn.Linear(hidden_dim * self.grid * self.grid, hidden_dim),
                nn.ReLU(inplace=True),
                nn.Linear(hidden_dim, n_classes),
            )
        elif self.readout == 'sort':
            self.classify = nn.Sequential(
                nn.Dropout(),
                nn.Linear(hidden_dim * sort_k, n_classes),
            )
        else:
            var = hidden_dim
            # 'gap' (feature net above) and 'set' (Set2Set) emit 2 * hidden_dim
            if self.readout == 'gap' or self.readout == 'set':
                var *= 2
            self.classify = nn.Linear(var, n_classes)
Example #2
0
File: test_nn.py Project: samzhaoziran/dgl
def test_gated_graph_conv():
    """Shape-check GatedGraphConv on a random 100-node graph with 3 edge types."""
    ctx = F.ctx()
    num_nodes, in_feats, out_feats, n_steps, num_etypes = 100, 5, 10, 5, 3

    g = dgl.DGLGraph(sp.sparse.random(num_nodes, num_nodes, density=0.1), readonly=True)
    ggconv = nn.GatedGraphConv(in_feats, out_feats, n_steps, num_etypes).to(ctx)
    # cycle edge-type IDs 0, 1, 2, 0, ... over all edges
    etypes = (th.arange(g.number_of_edges()) % num_etypes).to(ctx)
    feat = F.randn((num_nodes, in_feats))

    out = ggconv(g, feat, etypes)
    # only the output feature size is checked, not the values
    assert out.shape[-1] == out_feats
Example #3
0
File: test_nn.py Project: yangce0224/dgl
def test_gated_graph_conv(g, idtype):
    """Shape-check GatedGraphConv on graph ``g`` with 3 cycled edge types.

    ``g`` and ``idtype`` are presumably supplied by pytest parametrization
    (decorators not visible here).
    """
    ctx = F.ctx()
    in_feats, out_feats, n_steps, num_etypes = 5, 10, 5, 3

    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(in_feats, out_feats, n_steps, num_etypes).to(ctx)
    etypes = (th.arange(g.number_of_edges()) % num_etypes).to(ctx)
    feat = F.randn((g.number_of_nodes(), in_feats))

    out = ggconv(g, feat, etypes)
    # only the output feature size is checked, not the values
    assert out.shape[-1] == out_feats
Example #4
0
File: test_nn.py Project: lygztq/dgl
def test_gated_graph_conv_one_etype(g, idtype):
    """Check that all-zero edge types match calling without ``etypes``.

    With a single edge type (last constructor argument 1), passing an explicit
    edge-type tensor of zeros must produce the same output as omitting it.
    ``g`` and ``idtype`` are presumably supplied by pytest parametrization
    (decorators not visible here).
    """
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    # Edge-type IDs are categorical indices, so build an integer tensor;
    # th.zeros would otherwise default to a floating-point dtype.
    etypes = th.zeros(g.number_of_edges(), dtype=th.long)
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    # both call forms must agree; values are otherwise only shape-checked
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10