import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GATConv, SAGPooling
from torch_geometric.nn import global_max_pool as gmp, global_mean_pool as gap

# EdgeConvRot, qmul, inv_q, update_attr, edge_model and my_smooth_l1_loss_new are
# project-specific helpers and are assumed to be importable from the surrounding repo.


class ScoreNetwork(nn.Module):

    def __init__(self, top_k=4):
        super(ScoreNetwork, self).__init__()
        self.no_features = 128
        # top_k candidates with four values each, plus one level value per candidate
        self.input_node_feat = 4 * top_k + top_k
        self.dropout_ratio = 0.6

        self.conv1 = EdgeConvRot(self.input_node_feat, 4, self.no_features)
        self.conv2 = EdgeConvRot(self.no_features, self.no_features, self.no_features)
        self.conv2_pooling = SAGPooling(in_channels=self.no_features, GNN=GATConv)

        self.conv3 = EdgeConvRot(self.no_features, self.no_features, self.no_features)
        self.conv3_pooling = SAGPooling(in_channels=self.no_features, GNN=GATConv)

        self.conv4 = EdgeConvRot(self.no_features, self.no_features, self.no_features)
        self.conv4_pooling = SAGPooling(in_channels=self.no_features, GNN=GATConv)

        self.lin1 = torch.nn.Linear(self.no_features*2, self.no_features)
        self.lin2 = torch.nn.Linear(self.no_features, self.no_features//2)
        self.lin3 = torch.nn.Linear(self.no_features//2, top_k)

    def forward(self, node_feat, node_level, edge_index, edge_feat):

        N = node_feat.shape[0]
        K = node_feat.shape[1]

        E = edge_feat.shape[0]
        # edge_feat_ = update_attr_batch(node_feat, edge_index, edge_feat).view(-1, 4)

        # Flatten the per-node candidate features and append the level features.
        node_feat = node_feat.view(N, -1)
        node_feat = torch.cat([node_feat, node_level], dim=1)

        x1, edge_x1 = self.conv1(node_feat, edge_index, edge_feat)
        x1, edge_x1 = F.relu(x1), F.relu(edge_x1)

        x2, edge_x2 = self.conv2(x1, edge_index, edge_x1)
        x2, edge_x2 = F.relu(x2), F.relu(edge_x2)

        x2_pool, edge_x2_index_pool, edge_x2_pool, batch, p2_to_x2, _ = self.conv2_pooling.forward(
            x2, edge_index, edge_attr=edge_x2)
        l1 = torch.cat([gmp(x2_pool, batch), gap(x2_pool, batch)], dim=1)

        x3, edge_x3 = self.conv3(x2_pool, edge_x2_index_pool, edge_x2_pool)
        x3, edge_x3 = F.relu(x3), F.relu(edge_x3)
        x3_pool, edge_x3_index_pool, edge_x3_pool, batch, p3_to_x3, _ = self.conv3_pooling.forward(
            x3, edge_x2_index_pool, edge_attr=edge_x3)
        l2 = torch.cat([gmp(x3_pool, batch), gap(x3_pool, batch)], dim=1)

        x4, edge_x4 = self.conv4(x3_pool, edge_x3_index_pool, edge_x3_pool)
        x4, edge_x4 = F.relu(x4), F.relu(edge_x4)
        x4_pool, edge_x4_index_pool, edge_x4_pool, batch, p4_to_x4, _ = self.conv4_pooling.forward(
            x4, edge_x3_index_pool, edge_attr=edge_x4)
        l3 = torch.cat([gmp(x4_pool, batch), gap(x4_pool, batch)], dim=1)
        l = l1 + l2 + l3  # sum of the max/mean readouts from the three pooling stages

        x = F.relu(self.lin1(l))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        x = F.relu(self.lin2(x))
        x = self.lin3(x)

        return x
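
# A minimal smoke test for ScoreNetwork; this is a sketch, not part of the original
# source. Shapes are inferred from the constructor: node_feat carries top_k
# 4-component candidates per node, node_level one value per candidate, and
# edge_feat a 4-component feature per edge.
N, E, top_k = 16, 96, 4
score_net = ScoreNetwork(top_k=top_k)
node_feat = torch.randn(N, top_k, 4)        # (N, top_k, 4) candidate rotations
node_level = torch.randn(N, top_k)          # (N, top_k) per-candidate level features
edge_index = torch.randint(0, N, (2, E))    # (2, E) graph connectivity
edge_feat = torch.randn(E, 4)               # (E, 4) relative-rotation features
scores = score_net(node_feat, node_level, edge_index, edge_feat)
print(scores.shape)                         # (1, top_k) for a single (un-batched) graph
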
class PoolingFineNetWithAppearance(torch.nn.Module):

    def __init__(self, in_node_feat, in_edge_feat):
        super(PoolingFineNetWithAppearance, self).__init__()
        self.no_features = 64  # More features for large dataset
        self.conv1 = EdgeConvRot(in_node_feat + 4, in_edge_feat + 4, self.no_features)
        self.conv2 = EdgeConvRot(self.no_features, self.no_features + 4, self.no_features)
        self.conv3 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)

        self.conv3_sub_pre = EdgeConvRot(self.no_features, self.no_features, self.no_features * 2)
        self.conv3_sub_pooling = SAGPooling(in_channels=2*self.no_features, GNN=GATConv)
        self.conv3_sub1 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features*2)
        self.conv3_sub2 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)

        # self.conv3_sub_pooling = SAGPooling(in_channels=self.no_features, GNN=GATConv)
        # self.conv3_sub_pre = EdgeConvRot(self.no_features, self.no_features, self.no_features * 2)
        self.conv3_subsub_pooling = SAGPooling(in_channels=2*self.no_features, GNN=GATConv)
        self.conv3_subsub = EdgeConvRot(2 * self.no_features, 2 * self.no_features, 2*self.no_features)
        self.conv3_subsub2 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, 2*self.no_features)

        self.conv4 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)
        self.conv5 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)
        self.lin1 = Linear(self.no_features, 4)
        self.lin2 = Linear(self.no_features, 1)

        self.m = torch.nn.Sigmoid()

    def forward(self, data):
        x_org, edge_index, edge_attr, gt_q, beta, node_feat, edge_feat = data
        N = x_org.shape[0]
        E = edge_attr.shape[0]

        edge_attr_mod = update_attr(x_org, edge_index, edge_attr[:, :4])
        edge_feat = edge_feat.view(E, -1)
        node_feat = node_feat.view(N, -1)

        x1, edge_x1 = self.conv1(torch.cat([node_feat, x_org], dim=1), edge_index, torch.cat([edge_feat, edge_attr_mod], dim=1))
        x1 = F.relu(x1)
        edge_x1 = F.relu(edge_x1)

        x2, edge_x2 = self.conv2(x1, edge_index, torch.cat([edge_attr_mod, edge_x1], dim=1))
        x2 = F.relu(x2)
        edge_x2 = F.relu(edge_x2)

        x3, edge_x3 = self.conv3(torch.cat([x2, x1], dim=1), edge_index, torch.cat([edge_x2, edge_x1], dim=1))
        x3 = F.relu(x3)
        edge_x3 = F.relu(edge_x3)

        # Branch sub
        x3_s1, edge_x3_s1 = self.conv3_sub_pre(x3, edge_index, edge_x3)
        x3_s1, edge_x3_s1 = F.relu(x3_s1), F.relu(edge_x3_s1)

        x3_s1_pool, edge_s1_index_pool, edge_x3_s1_pool, batch, s1_to_ori, scores = self.conv3_sub_pooling.forward(x3_s1, edge_index, edge_attr=edge_x3_s1)
        x3_s2, edge_x3_s2 = self.conv3_sub1(x3_s1_pool, edge_s1_index_pool, edge_x3_s1_pool)
        x3_s2, edge_x3_s2 = F.relu(x3_s2), F.relu(edge_x3_s2)

        # Branch x3b
        x3_ss_pool, edge_ss_index_pool, edge_x3_ss_pool, batch1, ss_to_s1, _ = self.conv3_subsub_pooling.forward(x3_s2, edge_s1_index_pool, edge_attr=edge_x3_s2)
        x3_ss1, edge_x3_ss1 = self.conv3_subsub.forward(x3_ss_pool, edge_ss_index_pool, edge_x3_ss_pool)
        x3_ss1, edge_x3_ss1 = F.relu(x3_ss1), F.relu(edge_x3_ss1)

        x3_ss2, edge_x3_ss2 = self.conv3_subsub2(x3_ss1, edge_ss_index_pool, edge_x3_ss1)
        x3_ss2, edge_x3_ss2 = F.relu(x3_ss2), F.relu(edge_x3_ss2)
        x3_s2_ = x3_s2.clone()
        x3_s2_[ss_to_s1, :] = x3_s2_[ss_to_s1, :] + x3_ss2

        x3_s3, edge_x3_s3 = self.conv3_sub2(x3_s2_, edge_s1_index_pool, edge_x3_s2)
        x3_s3, edge_x3_s3 = F.relu(x3_s3), F.relu(edge_x3_s3)
        # Re-bind x3 to a clone before the indexed update so the tensor saved for the
        # sub-branch's backward pass is not modified in place.
        x3 = x3.clone()
        x3[s1_to_ori] = x3[s1_to_ori] + x3_s3

        # End route
        x4, edge_x4 = self.conv4(torch.cat([x3, x2], dim=1), edge_index, torch.cat([edge_x3, edge_x2], dim=1))
        x4, edge_x4 = F.relu(x4), F.relu(edge_x4)
        x5, edge_x5 = self.conv5(torch.cat([x4, x3], dim=1), edge_index, torch.cat([edge_x4, edge_x3], dim=1))
        x5, edge_x5 = F.relu(x5), F.relu(edge_x5)

        x = self.lin1(x5)
        # x = x + x_org  # qmul(x, x_org)
        x = qmul(x, x_org)
        x = F.normalize(x, p=2, dim=1)

        out_res = self.lin2(edge_x5)

        #     loss1 = inv_q(edge_model(data.y, edge_index)) - edge_model(x, edge_index)
        loss1 = qmul(inv_q(edge_model(gt_q, edge_index)), edge_model(x, edge_index))
        loss1 = F.normalize(loss1, p=2, dim=1)
        #        loss1 = qmul(inv_q(edge_attr[:, :4]), edge_model(x, edge_index))
        loss1 = my_smooth_l1_loss_new(loss1[:, 0:], beta, edge_index, x_org.size(0))

        return x, loss1, beta, out_res, (x1, x2, x3, x4, x5), (edge_x1, edge_x2, edge_x3, edge_x4, edge_x5)  # node_model(x, batch),
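
# A driver sketch for PoolingFineNetWithAppearance; the dimensions and the scalar
# beta below are assumptions, and the call relies on the project-local helpers used
# inside forward (update_attr, edge_model, qmul, inv_q, my_smooth_l1_loss_new).
N, E = 16, 96
in_node_feat, in_edge_feat = 16, 8
fine_net = PoolingFineNetWithAppearance(in_node_feat, in_edge_feat)
edge_index = torch.randint(0, N, (2, E))
x_org = F.normalize(torch.randn(N, 4), p=2, dim=1)       # current per-node quaternions
gt_q = F.normalize(torch.randn(N, 4), p=2, dim=1)        # ground-truth quaternions
edge_attr = F.normalize(torch.randn(E, 4), p=2, dim=1)   # relative rotations per edge
node_feat = torch.randn(N, in_node_feat)                 # appearance features per node
edge_feat = torch.randn(E, in_edge_feat)                 # appearance features per edge
beta = 0.1                                               # smooth-L1 transition point (assumed scalar)
x, loss1, beta_out, out_res, node_acts, edge_acts = fine_net(
    (x_org, edge_index, edge_attr, gt_q, beta, node_feat, edge_feat))
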
class AppearancePoolFusion(nn.Module):

    def __init__(self, in_node_feat, in_edge_feat, inplace=True, num_opt=4):
        super(AppearancePoolFusion, self).__init__()
        self.no_features = 128
        self.conv1 = EdgeConvRot(in_node_feat, in_edge_feat, self.no_features)
        self.conv2 = EdgeConvRot(self.no_features, self.no_features, self.no_features)
        self.conv3 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)

        self.conv3_sub_pre = EdgeConvRot(self.no_features, self.no_features, self.no_features * 2)
        self.conv3_sub_pooling = SAGPooling(in_channels=2*self.no_features, GNN=GATConv)
        self.conv3_sub1 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features*2)
        self.conv3_sub2 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)

        # self.conv3_sub_pooling = SAGPooling(in_channels=self.no_features, GNN=GATConv)
        # self.conv3_sub_pre = EdgeConvRot(self.no_features, self.no_features, self.no_features * 2)
        self.conv3_subsub_pooling = SAGPooling(in_channels=2*self.no_features, GNN=GATConv)
        self.conv3_subsub = EdgeConvRot(2 * self.no_features, 2 * self.no_features, 2*self.no_features)
        self.conv3_subsub2 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, 2*self.no_features)

        self.conv4 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)
        self.conv5 = EdgeConvRot(2 * self.no_features, 2 * self.no_features, self.no_features)
        self.lin1 = Linear(self.no_features, num_opt)

        self.inplace = inplace

    def forward(self, node_feat, edge_index, edge_feat):

        x1, edge_x1 = self.conv1(node_feat, edge_index, edge_feat)
        x1 = F.relu(x1, inplace=self.inplace)
        edge_x1 = F.relu(edge_x1, inplace=self.inplace)

        x2, edge_x2 = self.conv2(x1, edge_index, edge_x1)
        x2 = F.relu(x2, inplace=self.inplace)
        edge_x2 = F.relu(edge_x2, inplace=self.inplace)

        x3, edge_x3 = self.conv3(torch.cat([x2, x1], dim=1), edge_index, torch.cat([edge_x2, edge_x1], dim=1))
        x3 = F.relu(x3, inplace=self.inplace)
        edge_x3 = F.relu(edge_x3, inplace=self.inplace)

        # Branch sub
        x3_s1, edge_x3_s1 = self.conv3_sub_pre(x3, edge_index, edge_x3)
        x3_s1, edge_x3_s1 = F.relu(x3_s1, inplace=self.inplace), F.relu(edge_x3_s1, inplace=self.inplace)

        x3_s1_pool, edge_s1_index_pool, edge_x3_s1_pool, batch, s1_to_ori, scores = self.conv3_sub_pooling.forward(
            x3_s1, edge_index, edge_attr=edge_x3_s1)
        x3_s2, edge_x3_s2 = self.conv3_sub1(x3_s1_pool, edge_s1_index_pool, edge_x3_s1_pool)
        x3_s2, edge_x3_s2 = F.relu(x3_s2, inplace=self.inplace), F.relu(edge_x3_s2, inplace=self.inplace)

        # Branch x3b
        x3_ss_pool, edge_ss_index_pool, edge_x3_ss_pool, batch1, ss_to_s1, _ = self.conv3_subsub_pooling(x3_s2,
                                                                                                         edge_s1_index_pool,
                                                                                                         edge_attr=edge_x3_s2)
        x3_ss1, edge_x3_ss1 = self.conv3_subsub.forward(x3_ss_pool, edge_ss_index_pool, edge_x3_ss_pool)
        x3_ss1, edge_x3_ss1 = F.relu(x3_ss1, inplace=self.inplace), F.relu(edge_x3_ss1, inplace=self.inplace)
        x3_ss2, edge_x3_ss2 = self.conv3_subsub2(x3_ss1, edge_ss_index_pool, edge_x3_ss1)
        x3_ss2, edge_x3_ss2 = F.relu(x3_ss2, inplace=self.inplace), F.relu(edge_x3_ss2, inplace=self.inplace)
        x3_s2_2 = x3_s2.clone()
        x3_s2_2[ss_to_s1] = x3_s2[ss_to_s1] + x3_ss2

        x3_s3, edge_x3_s3 = self.conv3_sub2(x3_s2_2, edge_s1_index_pool, edge_x3_s2)
        x3_s3, edge_x3_s3 = F.relu(x3_s3, inplace=self.inplace), F.relu(edge_x3_s3, inplace=self.inplace)
        # x3_s3_global = gmp(x3_s3, batch)
        x3_2 = x3.clone()
        x3_2[s1_to_ori] = x3[s1_to_ori] + x3_s3

        x4, edge_x4 = self.conv4(torch.cat([x3_2, x2], dim=1), edge_index, torch.cat([edge_x3, edge_x2], dim=1))
        x4, edge_x4 = F.relu(x4, inplace=self.inplace), F.relu(edge_x4, inplace=self.inplace)
        x5, edge_x5 = self.conv5(torch.cat([x4, x3_2], dim=1), edge_index, torch.cat([edge_x4, edge_x3], dim=1))
        x5, edge_x5 = F.relu(x5, inplace=self.inplace), F.relu(edge_x5, inplace=self.inplace)

        x = self.lin1(x5)

        return x
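
# A quick shape check for AppearancePoolFusion; the feature dimensions below are
# assumptions, since in_node_feat/in_edge_feat are free constructor parameters.
N, E = 16, 96
fusion_net = AppearancePoolFusion(in_node_feat=32, in_edge_feat=16, num_opt=4)
node_feat = torch.randn(N, 32)
edge_index = torch.randint(0, N, (2, E))
edge_feat = torch.randn(E, 16)
opt_scores = fusion_net(node_feat, edge_index, edge_feat)  # per-node option scores, shape (N, 4)
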
class LevelFusion(nn.Module):
    def __init__(self, in_node_feat, in_edge_feat, inplace=True, num_opt=4):
        super(LevelFusion, self).__init__()
        self.no_features = 128
        self.reducer = nn.Linear(in_features=in_node_feat,
                                 out_features=self.no_features)
        self.conv1 = EdgeConvRot(self.no_features, in_edge_feat,
                                 self.no_features)
        self.conv2 = EdgeConvRot(self.no_features, self.no_features,
                                 self.no_features)
        self.conv3 = EdgeConvRot(2 * self.no_features, 2 * self.no_features,
                                 self.no_features)
        self.conv3b = EdgeConvRot(2 * self.no_features, 2 * self.no_features,
                                  self.no_features)

        self.conv3_sub_pre = EdgeConvRot(self.no_features, self.no_features,
                                         self.no_features * 2)
        self.conv3_sub_pooling = SAGPooling(in_channels=2 * self.no_features,
                                            GNN=GATConv)
        self.conv3_sub1 = EdgeConvRot(2 * self.no_features,
                                      2 * self.no_features, self.no_features)

        self.conv4 = EdgeConvRot(2 * self.no_features, self.no_features,
                                 self.no_features)
        # self.conv5 = EdgeConvRot(self.no_features, self.no_features, self.no_features)
        self.lin1 = Linear(self.no_features, 1)
        self.w_lin = Linear(self.no_features, 1)

        self.inplace = inplace

    def forward(self, node_feat, edge_feat, edge_index):

        node_feat_reduced = self.reducer.forward(node_feat)

        x1, edge_x1 = self.conv1(node_feat_reduced, edge_index, edge_feat)
        x1 = F.relu(x1, inplace=self.inplace)
        edge_x1 = F.relu(edge_x1, inplace=self.inplace)

        x2, edge_x2 = self.conv2(x1, edge_index, edge_x1)
        x2 = F.relu(x2, inplace=self.inplace)
        edge_x2 = F.relu(edge_x2, inplace=self.inplace)

        x3, edge_x3 = self.conv3(torch.cat([x2, x1], dim=1), edge_index,
                                 torch.cat([edge_x2, edge_x1], dim=1))
        x3 = F.relu(x3, inplace=self.inplace)
        edge_x3 = F.relu(edge_x3, inplace=self.inplace)

        x3, edge_x3 = self.conv3b(torch.cat([x3, x2], dim=1), edge_index,
                                  torch.cat([edge_x3, edge_x2], dim=1))
        x3 = F.relu(x3, inplace=self.inplace)
        edge_x3 = F.relu(edge_x3, inplace=self.inplace)

        # Branch sub
        x3_s1, edge_x3_s1 = self.conv3_sub_pre(x3, edge_index, edge_x3)
        x3_s1, edge_x3_s1 = F.relu(x3_s1, inplace=self.inplace), F.relu(
            edge_x3_s1, inplace=self.inplace)

        x3_s1_pool, edge_s1_index_pool, edge_x3_s1_pool, batch, s1_to_ori, scores = self.conv3_sub_pooling.forward(
            x3_s1, edge_index, edge_attr=edge_x3_s1)
        x3_s2, edge_x3_s2 = self.conv3_sub1(x3_s1_pool, edge_s1_index_pool,
                                            edge_x3_s1_pool)
        x3_s2, edge_x3_s2 = F.relu(x3_s2, inplace=self.inplace), F.relu(
            edge_x3_s2, inplace=self.inplace)
        #
        x3_2 = x3.clone()
        x3_2[s1_to_ori] = x3[s1_to_ori] + x3_s2

        x4, edge_x4 = self.conv4(torch.cat([x3_2, x3], dim=1), edge_index,
                                 edge_x3)
        x4, edge_x4 = F.relu(x4, inplace=self.inplace), F.relu(
            edge_x4, inplace=self.inplace)
        # x5, edge_x5 = self.conv5(x4, edge_index, edge_x4)
        # x5, edge_x5 = F.relu(x5, inplace=self.inplace), F.relu(edge_x5, inplace=self.inplace)

        x = self.lin1(x4)
        x = torch.sigmoid(x)

        w = self.w_lin(edge_x4)
        w = torch.sigmoid(w)
        return x, w
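
# A driver sketch for LevelFusion; the dimensions are assumptions. Note that this
# forward takes edge_feat before edge_index, unlike the classes above.
N, E = 16, 96
level_net = LevelFusion(in_node_feat=256, in_edge_feat=8)
node_feat = torch.randn(N, 256)
edge_feat = torch.randn(E, 8)
edge_index = torch.randint(0, N, (2, E))
node_conf, edge_w = level_net(node_feat, edge_feat, edge_index)
# node_conf: (N, 1) node confidences in [0, 1]; edge_w: (E, 1) edge weights in [0, 1]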