class MultiGATBaseConvs(nn.Module):
    """Five stacked ``GATConv`` layers, each followed by a ReLU.

    Layer widths (in_feats -> out_feats x num_heads):

    * l1: ``input_feat_channel`` -> 256 x ``n_head // 2``
    * l2: ``(n_head // 2) * 256`` -> 256 x ``n_head``
    * l3, l4: ``n_head * 256`` -> 256 x ``n_head``
    * l5: ``n_head * 256`` -> 512 x 1

    All layers use ``residual=True``. Between layers the per-head outputs
    are flattened with ``view(N, -1)`` so they match the next layer's
    ``in_feats``.

    Args:
        input_feat_channel: Width of the input node features.
        n_head: Number of attention heads for l2-l4; l1 uses half of it
            (floor division). Must be at least 2 so that l1 gets a head.

    Raises:
        ValueError: If ``n_head < 2``.
    """

    def __init__(self, input_feat_channel=512, n_head=16):
        super(MultiGATBaseConvs, self).__init__()
        # With n_head < 2, l1 would be built with num_heads=0 and l2 with
        # in_feats=0, which is not a usable network; fail early instead.
        if n_head < 2:
            raise ValueError(f"n_head must be >= 2, got {n_head}")
        self.n_head = n_head
        half_heads = n_head // 2
        self.l1 = GATConv(in_feats=input_feat_channel, out_feats=256,
                          num_heads=half_heads, residual=True)
        self.l2 = GATConv(in_feats=half_heads * 256, out_feats=256,
                          num_heads=n_head, residual=True)
        self.l3 = GATConv(in_feats=n_head * 256, out_feats=256,
                          num_heads=n_head, residual=True)
        self.l4 = GATConv(in_feats=n_head * 256, out_feats=256,
                          num_heads=n_head, residual=True)
        self.l5 = GATConv(in_feats=n_head * 256, out_feats=512,
                          num_heads=1, residual=True)

    def forward(self, graph, feat):
        """Run all five layers over ``graph``.

        Args:
            graph: Graph passed to every ``GATConv`` call.
            feat: Node features; ``feat.shape[0]`` is used as the node
                count N. Presumably (N, input_feat_channel) — confirm with
                callers.

        Returns:
            The ReLU'd output of l5 reshaped to ``(N, -1)``; presumably
            (N, 512), since l5 has a single head of 512 features.
        """
        n = feat.shape[0]
        # Call the modules rather than `.forward` so nn.Module hooks run.
        x = F.relu(self.l1(graph, feat))
        for layer in (self.l2, self.l3, self.l4, self.l5):
            # Flatten the head axis into features to match the next in_feats.
            x = F.relu(layer(graph, x.view(n, -1)))
        return x.view(n, -1)
class MultiGATBaseConvs(nn.Module):
    """Four stacked ``GATConv`` layers whose ``forward`` is replaced by ``forward2``.

    NOTE(review): this module defines more than one class named
    ``MultiGATBaseConvs``; at import time the one defined last replaces any
    earlier definition.

    Each layer's ``forward`` is rebound (via ``MethodType``) to ``forward2``,
    which is defined elsewhere. As unpacked in ``forward`` below, it returns
    a 3-tuple per layer.

    Layer widths (in_feats -> out_feats x num_heads):

    * l1: ``input_feat_channel`` -> 512 x ``n_head``
    * l2, l3: ``n_head * 512`` -> 512 x ``n_head``
    * l4: ``n_head * 512`` -> 512 x 1

    Args:
        input_feat_channel: Width of the input node features.
        n_head: Number of heads for l1-l3 (previously fixed at 16).
        plot_diff: If True, every forward pass shows a matplotlib histogram
            of ``relu(l1 output) - relu(l4 output)``. Off by default because
            doing this inside ``forward`` copies activations to the host
            and ``plt.show()`` can block on interactive backends.

    Raises:
        ValueError: If ``n_head < 1``.
    """

    def __init__(self, input_feat_channel=512, n_head=16, *, plot_diff=False):
        super(MultiGATBaseConvs, self).__init__()
        if n_head < 1:
            raise ValueError(f"n_head must be >= 1, got {n_head}")
        self.n_head = n_head
        self.plot_diff = plot_diff
        self.l1 = GATConv(in_feats=input_feat_channel, out_feats=512,
                          num_heads=n_head, residual=True)
        self.l2 = GATConv(in_feats=n_head * 512, out_feats=512,
                          num_heads=n_head, residual=True)
        self.l3 = GATConv(in_feats=n_head * 512, out_feats=512,
                          num_heads=n_head, residual=True)
        self.l4 = GATConv(in_feats=n_head * 512, out_feats=512,
                          num_heads=1, residual=True)
        # Rebind each layer's forward to forward2. nn.Module.__call__
        # dispatches through self.forward, so calling the layer uses it too.
        for layer in (self.l1, self.l2, self.l3, self.l4):
            layer.forward = MethodType(forward2, layer)

    def forward(self, graph, feat):
        """Run the four patched layers over ``graph``.

        Args:
            graph: Graph passed to every layer call.
            feat: Node features; ``feat.shape[0]`` is used as the node
                count N. Presumably (N, input_feat_channel) — confirm with
                callers.

        Returns:
            Tuple ``(out, attn, bef)``. ``out`` is the ReLU'd l4 output
            reshaped to ``(N, -1)``. ``attn`` and ``bef`` are the second and
            third values returned by l4's ``forward2``, with ``bef`` reshaped
            to ``(N, -1)``. Presumably these are attention weights and
            pre-activation features — confirm against ``forward2``.
        """
        n = feat.shape[0]
        # Call the modules rather than `.forward` so nn.Module hooks run.
        x, _, _ = self.l1(graph, feat)
        x1 = F.relu(x)
        x, _, _ = self.l2(graph, x1.view(n, -1))
        x = F.relu(x)
        x, _, _ = self.l3(graph, x.view(n, -1))
        x = F.relu(x)
        x, attn, bef = self.l4(graph, x.view(n, -1))
        x = F.relu(x)
        if self.plot_diff:
            self._plot_diff(x1, x)
        return x.view(n, -1), attn, bef.view(n, -1)

    @staticmethod
    def _plot_diff(before, after):
        """Show a histogram of ``before - after`` (debugging aid)."""
        # NOTE(review): `before` comes from an n_head-head layer and `after`
        # from a 1-head layer; the subtraction relies on broadcasting —
        # presumably (N, n_head, 512) - (N, 1, 512). Confirm with forward2.
        diff = (before - after).detach().cpu().numpy()
        plt.hist(np.asarray(diff).ravel(), bins=100)
        plt.ylabel('Freq.')
        plt.show()