def __init__(self, in_dim, hidden_dim, n_classes, hidden_layers, n_steps,
             readout, activation_func, dropout, grid, device):
    super(Classifier, self).__init__()
    self.device = device
    self.readout = readout
    self.layers = nn.ModuleList()
    self.batch_norms = nn.ModuleList()
    self.grid = grid

    # input layer
    self.layers.append(conv.GatedGraphConv(in_dim, hidden_dim, n_steps, 1))
    self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    # hidden layers
    for _ in range(hidden_layers):
        self.layers.append(conv.GatedGraphConv(hidden_dim, hidden_dim, n_steps, 1))
        self.batch_norms.append(nn.BatchNorm1d(hidden_dim))

    # dropout layer
    self.dropout = nn.Dropout(p=dropout)

    # readout layer (graph-level pooling); any unrecognized value falls
    # through to the custom SppPooling module
    if self.readout == 'max':
        self.readout_fcn = conv.MaxPooling()
    elif self.readout == 'mean':
        self.readout_fcn = conv.AvgPooling()
    elif self.readout == 'sum':
        self.readout_fcn = conv.SumPooling()
    elif self.readout == 'gap':
        # gate_nn scores each node; feat_nn doubles the feature width
        self.readout_fcn = conv.GlobalAttentionPooling(
            nn.Linear(hidden_dim, 1), nn.Linear(hidden_dim, hidden_dim * 2))
    elif self.readout == 'sort':
        self.readout_fcn = conv.SortPooling(100)
    elif self.readout == 'set':
        self.readout_fcn = conv.Set2Set(hidden_dim, 2, 2)
    else:
        self.readout_fcn = SppPooling(hidden_dim, self.grid)

    # classifier head; its input width depends on the chosen readout
    if self.readout == 'spp':
        self.classify = nn.Sequential(
            nn.Dropout(),
            nn.Linear(hidden_dim * self.grid * self.grid, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, n_classes),
        )
    elif self.readout == 'sort':
        # SortPooling keeps the top 100 nodes, so the pooled features are
        # hidden_dim * 100 wide
        self.classify = nn.Sequential(
            nn.Dropout(),
            nn.Linear(hidden_dim * 100, n_classes),
        )
    else:
        # GlobalAttentionPooling (as configured above) and Set2Set both
        # emit 2 * hidden_dim features per graph
        var = hidden_dim
        if self.readout in ('gap', 'set'):
            var *= 2
        self.classify = nn.Linear(var, n_classes)
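# NOTE (added sketch, not part of the original file): __init__ above builds
# layers, batch norms, dropout, a readout, and a classifier head, but the
# forward pass is not shown in this snippet. Below is one plausible
# composition under the usual DGL pattern; the real method may differ.
# Assumptions: `nn` is torch.nn and `conv` is dgl.nn.pytorch.conv;
# activation_func is received by __init__ but never stored, so plain ReLU
# stands in for it; SppPooling is assumed to share the (graph, feat) calling
# convention of DGL's built-in pooling modules.
def forward(self, g, h):
    for layer, bn in zip(self.layers, self.batch_norms):
        h = layer(g, h)                  # gated graph convolution
        h = bn(h)
        h = nn.functional.relu(h)        # stand-in for activation_func
        h = self.dropout(h)
    hg = self.readout_fcn(g, h)          # graph-level pooling -> (B, D)
    return self.classify(hg)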
def test_gated_graph_conv():
    # legacy variant: builds a readonly DGLGraph from a random scipy sparse
    # adjacency matrix (F is the DGL test-suite backend shim)
    ctx = F.ctx()
    g = dgl.DGLGraph(sp.sparse.random(100, 100, density=0.1), readonly=True)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((100, 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do a shape check
    assert h.shape[-1] == 10
def test_gated_graph_conv(g, idtype):
    # parametrized variant: g and idtype are supplied by the test harness
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 3)
    etypes = th.arange(g.number_of_edges()) % 3
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    # currently we only do a shape check
    assert h.shape[-1] == 10
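# Added sketch (not from the test file): the same shape check written as a
# self-contained script, without the test harness's backend shim F or its
# graph fixtures. Assumes only a stock DGL + PyTorch install; the graph and
# feature sizes are arbitrary.
import torch as th
import dgl
from dgl.nn.pytorch import GatedGraphConv

g = dgl.rand_graph(100, 1000)                   # 100 nodes, 1000 random edges
ggconv = GatedGraphConv(5, 10, n_steps=5, n_etypes=3)
etypes = th.arange(g.num_edges()) % 3           # cycle edges through 3 types
feat = th.randn(g.num_nodes(), 5)
h = ggconv(g, feat, etypes)
assert h.shape == (100, 10)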
def test_gated_graph_conv_one_etype(g, idtype):
    # with n_etypes=1, etypes may be omitted entirely; the two calls below
    # are expected to produce identical results
    ctx = F.ctx()
    g = g.astype(idtype).to(ctx)
    ggconv = nn.GatedGraphConv(5, 10, 5, 1)
    etypes = th.zeros(g.number_of_edges())
    feat = F.randn((g.number_of_nodes(), 5))
    ggconv = ggconv.to(ctx)
    etypes = etypes.to(ctx)

    h = ggconv(g, feat, etypes)
    h2 = ggconv(g, feat)
    assert F.allclose(h, h2)
    assert h.shape[-1] == 10
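# Added sketch (not from the test file): the one-etype equivalence above,
# replayed on a tiny concrete graph so it runs standalone. Per the test's
# expectation, with n_etypes=1 passing all-zero edge types and omitting the
# argument should give the same output.
import torch as th
import dgl
from dgl.nn.pytorch import GatedGraphConv

g = dgl.graph(([0, 1, 2], [1, 2, 0]))           # 3-node directed cycle
ggconv = GatedGraphConv(5, 10, n_steps=5, n_etypes=1)
feat = th.randn(g.num_nodes(), 5)
h = ggconv(g, feat, th.zeros(g.num_edges(), dtype=th.long))
h2 = ggconv(g, feat)                            # etypes omitted
assert th.allclose(h, h2)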