Esempio n. 1
0
class MultiGATBaseConvs(nn.Module):
    """Stack of five residual GATConv layers, each followed by ReLU.

    Layer shapes (``out_feats`` x ``num_heads``):

    * ``l1``: ``input_feat_channel`` -> 256 x ``n_head // 2``
    * ``l2``-``l4``: -> 256 x ``n_head``
    * ``l5``: -> 512 x 1

    Each layer's ``in_feats`` equals ``num_heads * out_feats`` of the layer
    before it, so ``forward`` flattens the head axis between layers.

    Args:
        input_feat_channel: Number of channels in the input node features.
        n_head: Number of attention heads in ``l2``-``l4``; ``l1`` uses
            ``n_head // 2``. Must be at least 2.

    Raises:
        ValueError: If ``n_head`` is less than 2.
    """

    def __init__(self, input_feat_channel=512, n_head=16):
        super().__init__()
        if n_head < 2:
            # n_head // 2 would leave l1 with zero heads, i.e. zero output
            # channels for the rest of the stack to consume.
            raise ValueError(f"n_head must be at least 2, got {n_head}")
        self.n_head = n_head
        half_heads = n_head // 2
        self.l1 = GATConv(in_feats=input_feat_channel, out_feats=256,
                          num_heads=half_heads, residual=True)
        self.l2 = GATConv(in_feats=half_heads * 256, out_feats=256,
                          num_heads=n_head, residual=True)
        self.l3 = GATConv(in_feats=n_head * 256, out_feats=256,
                          num_heads=n_head, residual=True)
        self.l4 = GATConv(in_feats=n_head * 256, out_feats=256,
                          num_heads=n_head, residual=True)
        self.l5 = GATConv(in_feats=n_head * 256, out_feats=512,
                          num_heads=1, residual=True)

    def forward(self, graph, feat):
        """Apply ``l1``..``l5`` in order, each followed by ReLU.

        Args:
            graph: Graph passed unchanged to every GATConv layer.
            feat: Input node features; ``feat.shape[0]`` is taken as the
                number of nodes ``N``.

        Returns:
            The final activations reshaped to ``(N, -1)``. ``l5`` has one
            head of 512 features, so this is expected to be ``(N, 512)``.
        """
        n_nodes = feat.shape[0]
        x = feat
        for layer in (self.l1, self.l2, self.l3, self.l4, self.l5):
            # Call the module (not .forward) so registered hooks run, then
            # merge the head axis into the channel axis for the next layer.
            x = F.relu(layer(graph, x)).reshape(n_nodes, -1)
        return x
Esempio n. 2
0
class MultiGATBaseConvs(nn.Module):
    """Four residual GATConv layers whose ``forward`` is replaced by ``forward2``.

    Layer shapes (``out_feats`` x ``num_heads``):

    * ``l1``: ``input_feat_channel`` -> 512 x ``n_head``
    * ``l2``, ``l3``: -> 512 x ``n_head``
    * ``l4``: -> 512 x 1

    Every layer's ``forward`` is rebound to ``forward2`` (defined elsewhere).
    Judging by how ``forward`` unpacks it, ``forward2`` returns a 3-tuple
    ``(output, attn, bef)``.

    Args:
        input_feat_channel: Number of channels in the input node features.
        n_head: Number of attention heads in ``l1``-``l3`` (previously
            hard-coded to 16). Must be at least 1.

    Raises:
        ValueError: If ``n_head`` is less than 1.
    """

    def __init__(self, input_feat_channel=512, n_head=16):
        super().__init__()
        if n_head < 1:
            # Zero heads would give the stack zero-width features to pass on.
            raise ValueError(f"n_head must be at least 1, got {n_head}")
        self.n_head = n_head
        self.l1 = GATConv(in_feats=input_feat_channel,
                          out_feats=512,
                          num_heads=n_head,
                          residual=True)
        self.l2 = GATConv(in_feats=n_head * 512,
                          out_feats=512,
                          num_heads=n_head,
                          residual=True)
        self.l3 = GATConv(in_feats=n_head * 512,
                          out_feats=512,
                          num_heads=n_head,
                          residual=True)
        self.l4 = GATConv(in_feats=n_head * 512,
                          out_feats=512,
                          num_heads=1,
                          residual=True)
        for layer in (self.l1, self.l2, self.l3, self.l4):
            # Bind forward2 on this instance only; nn.Module.__call__ looks
            # up self.forward, so calling the layer dispatches to forward2.
            layer.forward = MethodType(forward2, layer)

    def forward(self, graph, feat, *, plot_diff=False):
        """Apply ``l1``..``l4`` in order, each followed by ReLU.

        Args:
            graph: Graph passed unchanged to every layer.
            feat: Input node features; ``feat.shape[0]`` is taken as the
                number of nodes ``N``.
            plot_diff: When True, show a histogram of the difference between
                the first layer's and the last layer's activations. This was
                previously done on every call; it forces a device-to-host
                copy and ``plt.show()`` may block, so it is now opt-in.

        Returns:
            ``(x, attn, bef)`` where ``x`` is the final activation reshaped
            to ``(N, -1)``, and ``attn`` / ``bef`` are the second and third
            values returned by ``l4`` (``bef`` reshaped to ``(N, -1)``).
            NOTE(review): their exact meaning depends on ``forward2`` —
            presumably attention weights and pre-activation features; confirm.
        """
        n_nodes = feat.shape[0]

        x, _, _ = self.l1(graph, feat)
        first_out = F.relu(x)
        x, _, _ = self.l2(graph, first_out.reshape(n_nodes, -1))
        x = F.relu(x)
        x, _, _ = self.l3(graph, x.reshape(n_nodes, -1))
        x = F.relu(x)
        x, attn, bef = self.l4(graph, x.reshape(n_nodes, -1))
        x = F.relu(x)

        if plot_diff:
            # first_out has n_head heads and x has one, so the subtraction
            # broadcasts over the head axis (assumes forward2 keeps the
            # (N, heads, out_feats) layout — TODO confirm).
            diff = (first_out - x).detach().cpu().numpy()
            plt.hist(diff.ravel(), bins=100)
            plt.ylabel('Freq.')
            plt.show()

        return x.reshape(n_nodes, -1), attn, bef.reshape(n_nodes, -1)