Beispiel #1
0
    def forward(
        self,
        data_dict,
    ):
        """Match every lexicographic pair of graphs with a hypergraph NGM.

        For each graph, backbone node/edge features are sampled at the
        keypoints, stored on the graph and passed through
        ``self.message_pass_node_features``. For every graph pair, first-,
        second- and third-order affinities are built, the ``gnn_layer_{i}``
        stack embeds the association graph, and ``self.sinkhorn`` +
        ``hungarian`` give soft and discrete matchings. With more than two
        graphs, the pairwise soft matchings are written into a joint score
        matrix, projected onto its top ``self.univ_size`` eigenvectors and
        re-normalized per pair with ``self.sinkhorn_mgm``.

        Args:
            data_dict: batch dictionary. Keys read here: ``'images'``,
                ``'Ps'``, ``'ns'``, ``'pyg_graphs'``, ``'batch_size'``,
                ``'KGHs'`` and ``'gt_perm_mat'``.

        Returns:
            The same ``data_dict``, updated in place. ``'aff_mat'`` is set
            when exactly two graphs are given. For ``cfg.PROBLEM.TYPE ==
            '2GM'`` it gains ``'ds_mat'`` and ``'perm_mat'``; for ``'MGM'`` it
            gains ``'ds_mat_list'``, ``'perm_mat_list'``, ``'graph_indices'``
            and ``'gt_perm_mat_list'``.

        Raises:
            ValueError: if ``'gt_perm_mat'`` is missing or the problem type is
                neither ``'2GM'`` nor ``'MGM'``.
        """
        images = data_dict['images']
        points = data_dict['Ps']
        n_points = data_dict['ns']
        graphs = data_dict['pyg_graphs']
        batch_size = data_dict['batch_size']
        num_graphs = len(images)

        if cfg.PROBLEM.TYPE == '2GM' and 'gt_perm_mat' in data_dict:
            gt_perm_mats = [data_dict['gt_perm_mat']]
        elif cfg.PROBLEM.TYPE == 'MGM' and 'gt_perm_mat' in data_dict:
            # pairwise ground truth composed from per-graph permutation mats
            perm_mat_list = data_dict['gt_perm_mat']
            gt_perm_mats = [
                torch.bmm(pm_src, pm_tgt.transpose(1, 2))
                for pm_src, pm_tgt in lexico_iter(perm_mat_list)
            ]
        else:
            raise ValueError(
                'Ground truth information is required during training.')

        global_list = []
        orig_graph_list = []
        for image, p, n_p, graph in zip(images, points, n_points, graphs):
            # extract feature
            nodes = self.node_layers(image)
            edges = self.edge_layers(nodes)

            # the global descriptor is taken before channel normalization
            global_list.append(
                self.final_layers(edges).reshape((nodes.shape[0], -1)))
            nodes = normalize_over_channels(nodes)
            edges = normalize_over_channels(edges)

            # arrange features
            U = concat_features(feature_align(nodes, p, n_p, self.rescale),
                                n_p)
            F = concat_features(feature_align(edges, p, n_p, self.rescale),
                                n_p)
            node_features = torch.cat((U, F), dim=1)
            graph.x = node_features

            graph = self.message_pass_node_features(graph)
            orig_graph = self.build_edge_features_from_node_features(
                graph, hyperedge=True)
            orig_graph_list.append(orig_graph)

        global_weights_list = [
            torch.cat([global_src, global_tgt], axis=-1)
            for global_src, global_tgt in lexico_iter(global_list)
        ]

        global_weights_list = [
            normalize_over_channels(g) for g in global_weights_list
        ]

        unary_affs_list = [
            self.vertex_affinity([item.x for item in g_1],
                                 [item.x for item in g_2], global_weights)
            for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list),
                                                  global_weights_list)
        ]

        quadratic_affs_list = [
            self.edge_affinity([item.edge_attr for item in g_1],
                               [item.edge_attr
                                for item in g_2], global_weights)
            for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list),
                                                  global_weights_list)
        ]

        quadratic_affs_list = [[0.5 * x for x in quadratic_affs]
                               for quadratic_affs in quadratic_affs_list]

        # third-order affinities do not use the global weights
        order3_affs_list = [
            hyperedge_affinity([item.hyperedge_attr for item in g_1],
                               [item.hyperedge_attr for item in g_2])
            for g_1, g_2 in lexico_iter(orig_graph_list)
        ]

        s_list, mgm_s_list, x_list, mgm_x_list, indices = [], [], [], [], []

        for unary_affs, quadratic_affs, order3_affs, (g1, g2), (idx1, idx2) in \
            zip(unary_affs_list, quadratic_affs_list, order3_affs_list, lexico_iter(orig_graph_list), lexico_iter(range(num_graphs))):
            # with two graphs 'KGHs' holds a single (G, H) pair; otherwise it
            # is keyed by "idx1,idx2"
            kro_G, kro_H = data_dict['KGHs'] if num_graphs == 2 else data_dict[
                'KGHs']['{},{}'.format(idx1, idx2)]
            Kp = torch.stack(pad_tensor(unary_affs), dim=0)
            Ke = torch.stack(pad_tensor(quadratic_affs), dim=0)
            He = torch.stack(pad_tensor(order3_affs))
            K = construct_aff_mat(Ke, Kp, kro_G, kro_H)

            # build hyper graph tensor H (batch, n1*n2, n1*n2, n1*n2)
            hyperE1, nmax1, emax1 = construct_hyperE(g1, batch_size, He.device)
            hyperE2, nmax2, emax2 = construct_hyperE(g2, batch_size, He.device)
            H = torch.bmm(torch.bmm(
                hyperE1.reshape(batch_size, -1, emax1), He), hyperE2.reshape(batch_size, -1, emax2).transpose(1, 2))\
                .reshape(batch_size, nmax1, nmax1, nmax1, nmax2, nmax2, nmax2).permute(0, 4, 1, 5, 2, 6, 3)\
                .reshape(batch_size, nmax1*nmax2, nmax1*nmax2, nmax1*nmax2)

            if num_graphs == 2: data_dict['aff_mat'] = K

            if cfg.NGM.FIRST_ORDER:
                emb = Kp.transpose(1, 2).contiguous().view(Kp.shape[0], -1, 1)
            else:
                emb = torch.ones(K.shape[0], K.shape[1], 1, device=K.device)

            if cfg.NGM.POSITIVE_EDGES:
                adjs = [(K > 0).to(K.dtype), (H > 0).to(H.dtype)]
            else:
                adjs = [(K != 0).to(K.dtype), (H != 0).to(H.dtype)]

            emb_edges = [
                K.unsqueeze(-1),
                to_sparse(H.unsqueeze(-1), dense_dim=2)
            ]

            # NGM qap solver
            for i in range(self.gnn_layer):
                gnn_layer = getattr(self, 'gnn_layer_{}'.format(i))
                emb_edges, emb = gnn_layer(adjs, emb_edges, emb,
                                           n_points[idx1],
                                           n_points[idx2])  #, weight=[1, 0.1])

            v = self.classifier(emb)
            s = v.view(v.shape[0], points[idx2].shape[1], -1).transpose(1, 2)

            ss = self.sinkhorn(s,
                               n_points[idx1],
                               n_points[idx2],
                               dummy_row=True)
            x = hungarian(ss, n_points[idx1], n_points[idx2])
            s_list.append(ss)
            x_list.append(x)
            indices.append((idx1, idx2))

        if num_graphs > 2:
            # Block offsets of each graph inside joint_S. The trailing zero
            # makes joint_indices[idx - 1] evaluate to 0 for idx == 0.
            joint_indices = torch.cat(
                (torch.cumsum(torch.stack([torch.max(n_p) for n_p in n_points]),
                              dim=0),
                 torch.zeros((1, ), dtype=torch.long, device=K.device)))
            joint_S = torch.zeros(batch_size,
                                  torch.max(joint_indices),
                                  torch.max(joint_indices),
                                  device=K.device)
            for idx in range(num_graphs):
                start = joint_indices[idx - 1]
                for b in range(batch_size):
                    joint_S[b, start:start + n_points[idx][b], start:start +
                            n_points[idx][b]] += torch.eye(n_points[idx][b],
                                                           device=K.device)

            # both branches write into the upper off-diagonal blocks only
            for (idx1, idx2), s in zip(indices, s_list):
                if idx1 > idx2:
                    joint_S[:, joint_indices[idx2 - 1]:joint_indices[idx2],
                            joint_indices[idx1 - 1]:
                            joint_indices[idx1]] += s.transpose(1, 2)
                else:
                    joint_S[:, joint_indices[idx1 - 1]:joint_indices[idx1],
                            joint_indices[idx2 - 1]:joint_indices[idx2]] += s

            matching_s = []
            for b in range(batch_size):
                # torch.symeig was removed from PyTorch; torch.linalg.eigh is
                # the replacement (eigenvalues ascending, as before).
                # UPLO='U' reproduces symeig's default upper=True and is
                # required because only the upper blocks of joint_S are set.
                e, v = torch.linalg.eigh(joint_S[b], UPLO='U')
                diff = e[-self.univ_size:-1] - e[-self.univ_size + 1:]
                # near-degenerate top eigenvalues: skip the projection in
                # training and keep the raw joint scores
                if self.training and torch.min(torch.abs(diff)) <= 1e-4:
                    matching_s.append(joint_S[b])
                else:
                    matching_s.append(
                        num_graphs *
                        torch.mm(v[:, -self.univ_size:],
                                 v[:, -self.univ_size:].transpose(0, 1)))

            matching_s = torch.stack(matching_s, dim=0)

            for idx1, idx2 in indices:
                s = matching_s[:, joint_indices[idx1 - 1]:joint_indices[idx1],
                               joint_indices[idx2 - 1]:joint_indices[idx2]]
                s = self.sinkhorn_mgm(
                    torch.log(torch.relu(s)), n_points[idx1], n_points[idx2]
                )  # only perform row/col norm, do not perform exp
                x = hungarian(s, n_points[idx1], n_points[idx2])

                mgm_s_list.append(s)
                mgm_x_list.append(x)

        if cfg.PROBLEM.TYPE == '2GM':
            data_dict.update({'ds_mat': s_list[0], 'perm_mat': x_list[0]})
        elif cfg.PROBLEM.TYPE == 'MGM':
            data_dict.update({
                'ds_mat_list': mgm_s_list,
                'perm_mat_list': mgm_x_list,
                'graph_indices': indices,
                'gt_perm_mat_list': gt_perm_mats
            })

        return data_dict
Beispiel #2
0
    def forward(self, data_dict, **kwargs):
        """Jointly match all graphs in ``data_dict`` with spectral fusion.

        Each pair of graphs (from ``combinations``) is scored by
        ``self.__ngm_forward``. The pairwise scores are written into a joint
        block matrix with ones on its diagonal. Its top eigenvectors (by
        absolute eigenvalue, as many as the first graph has keypoint slots)
        give a fused score. That score is re-normalized by ``self.sinkhorn2``
        and discretized by ``hungarian``.

        Args:
            data_dict: batch dictionary holding either ``'images'``, whose
                features are extracted here, or ``'features'``, which are
                used as-is. It also holds ``'Ps'``, ``'ns'``, ``'Gs'``,
                ``'Hs'``, ``'Gs_tgt'``, ``'Hs_tgt'`` and ``'KGHs'``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``data_dict`` updated in place with ``'ds_mat_list'``,
            ``'perm_mat_list'`` and ``'graph_indices'``. Each holds one entry
            per ``combinations(range(num_graphs), 2)`` pair.

        Raises:
            ValueError: if neither ``'images'`` nor ``'features'`` is present.
        """
        if 'images' in data_dict:
            data = data_dict['images']
        elif 'features' in data_dict:
            data = data_dict['features']
        else:
            raise ValueError('Unknown data type for this model.')

        # graph structure, shared by both input modes
        Ps = data_dict['Ps']
        ns = data_dict['ns']
        Gs = data_dict['Gs']
        Hs = data_dict['Hs']
        Gs_tgt = data_dict['Gs_tgt']
        Hs_tgt = data_dict['Hs_tgt']
        KGs = {k: v[0] for k, v in data_dict['KGHs'].items()}
        KHs = {k: v[1] for k, v in data_dict['KGHs'].items()}

        batch_size = data[0].shape[0]
        device = data[0].device

        if 'images' in data_dict:
            # run the backbone once on all graphs concatenated along dim 0,
            # then split back into one chunk of batch_size per graph
            data_cat = torch.cat(data, dim=0)
            P_cat = torch.cat(pad_tensor(Ps), dim=0)
            n_cat = torch.cat(ns, dim=0)
            node = self.node_layers(data_cat)
            edge = self.edge_layers(node)
            U = feature_align(node, P_cat, n_cat, self.rescale)
            F = feature_align(edge, P_cat, n_cat, self.rescale)
            feats = torch.cat((U, F), dim=1)
            feats = self.l2norm(feats)
            feats = torch.split(feats, batch_size, dim=0)
        else:
            feats = data

        # extract reference graph feature
        feat_list = []
        joint_indices = [0]  # block offsets of each graph inside joint_S
        iterator = zip(feats, Ps, Gs, Hs, Gs_tgt, Hs_tgt, ns)
        for idx, (feat, P, G, H, G_tgt, H_tgt, n) in enumerate(iterator):
            feat_list.append((idx, feat, P, G, H, G_tgt, H_tgt, n))
            joint_indices.append(joint_indices[-1] + P.shape[1])

        joint_S = torch.zeros(batch_size,
                              joint_indices[-1],
                              joint_indices[-1],
                              device=device)
        # torch.diagonal returns a view, so this sets joint_S's diagonal to 1
        joint_S_diag = torch.diagonal(joint_S, dim1=1, dim2=2)
        joint_S_diag += 1

        pred_s = []
        pred_x = []
        indices = []

        for src, tgt in combinations(feat_list, 2):
            # pca forward
            src_idx, src_feat, P_src, G_src, H_src, _, __, n_src = src
            tgt_idx, tgt_feat, P_tgt, _, __, G_tgt, H_tgt, n_tgt = tgt
            K_G = KGs['{},{}'.format(src_idx, tgt_idx)]
            K_H = KHs['{},{}'.format(src_idx, tgt_idx)]
            s = self.__ngm_forward(src_feat, tgt_feat, P_src, P_tgt, G_src,
                                   G_tgt, H_src, H_tgt, K_G, K_H, n_src, n_tgt)

            # both branches fill an upper off-diagonal block of joint_S
            if src_idx > tgt_idx:
                joint_S[:, joint_indices[tgt_idx]:joint_indices[tgt_idx + 1],
                        joint_indices[src_idx]:joint_indices[
                            src_idx + 1]] += s.transpose(1, 2)
            else:
                joint_S[:, joint_indices[src_idx]:joint_indices[src_idx + 1],
                        joint_indices[tgt_idx]:joint_indices[tgt_idx + 1]] += s

        matching_s = []
        for b in range(batch_size):
            # torch.symeig was removed from PyTorch; torch.linalg.eigh is the
            # replacement. UPLO='U' reproduces symeig's default upper=True and
            # is required because only the upper blocks of joint_S are set.
            e, v = torch.linalg.eigh(joint_S[b], UPLO='U')
            topargs = torch.argsort(torch.abs(e),
                                    descending=True)[:joint_indices[1]]
            diff = e[topargs[:-1]] - e[topargs[1:]]
            # fall back to the raw joint scores when the selected eigenvalues
            # are (nearly) degenerate
            if torch.min(torch.abs(diff)) > 1e-4:
                matching_s.append(
                    len(data) *
                    torch.mm(v[:, topargs], v[:, topargs].transpose(0, 1)))
            else:
                matching_s.append(joint_S[b])

        matching_s = torch.stack(matching_s, dim=0)

        for idx1, idx2 in combinations(range(len(data)), 2):
            s = matching_s[:, joint_indices[idx1]:joint_indices[idx1 + 1],
                           joint_indices[idx2]:joint_indices[idx2 + 1]]
            s = self.sinkhorn2(s)

            pred_s.append(s)
            pred_x.append(hungarian(s))
            indices.append((idx1, idx2))

        data_dict.update({
            'ds_mat_list': pred_s,
            'perm_mat_list': pred_x,
            'graph_indices': indices,
        })
        return data_dict
Beispiel #3
0
    def forward(self, data_dict, **kwargs):
        """Match a source/target graph pair with a cross-graph GNN.

        Node embeddings come from either backbone features sampled at the
        keypoints or pre-split synthetic features. Edge embeddings are
        initialized as ``exp(-dist / self.rescale[0])`` over pairwise keypoint
        distances. Each ``gnn_layer_{i}`` updates both, then ``affinity_{i}``
        and ``self.sinkhorn`` produce a soft matching. At the second-to-last
        layer, node embeddings are mixed with the other graph's embeddings
        (transported by that soft matching) through ``cross_graph_{i}``.

        Args:
            data_dict: batch dictionary with either ``'images'`` or
                ``'features'``, plus ``'Ps'``, ``'ns'``, ``'Gs'`` and ``'Hs'``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``data_dict`` updated in place with ``'ds_mat'`` (the last layer's
            Sinkhorn output) and ``'perm_mat'`` (its Hungarian
            discretization).

        Raises:
            ValueError: if neither ``'images'`` nor ``'features'`` is present.
        """
        if 'images' in data_dict:
            # real image data
            src, tgt = data_dict['images']
            P_src, P_tgt = data_dict['Ps']
            ns_src, ns_tgt = data_dict['ns']
            G_src, G_tgt = data_dict['Gs']
            H_src, H_tgt = data_dict['Hs']
            # extract feature
            src_node = self.node_layers(src)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt)
            tgt_edge = self.edge_layers(tgt_node)

            # feature normalization
            src_node = self.l2norm(src_node)
            src_edge = self.l2norm(src_edge)
            tgt_node = self.l2norm(tgt_node)
            tgt_edge = self.l2norm(tgt_edge)

            # arrange features
            U_src = feature_align(src_node, P_src, ns_src, self.rescale)
            F_src = feature_align(src_edge, P_src, ns_src, self.rescale)
            U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, self.rescale)
            F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, self.rescale)
        elif 'features' in data_dict:
            # synthetic data
            src, tgt = data_dict['features']
            # keypoints are needed below for the distance-based edge
            # embeddings in synthetic mode as well
            P_src, P_tgt = data_dict['Ps']
            ns_src, ns_tgt = data_dict['ns']
            G_src, G_tgt = data_dict['Gs']
            H_src, H_tgt = data_dict['Hs']

            # first half of dim 1 is the node part, second half the edge part
            U_src = src[:, :src.shape[1] // 2, :]
            F_src = src[:, src.shape[1] // 2:, :]
            U_tgt = tgt[:, :tgt.shape[1] // 2, :]
            F_tgt = tgt[:, tgt.shape[1] // 2:, :]
        else:
            raise ValueError('Unknown data type for this model.')

        # pairwise keypoint distances (no gradient through the geometry)
        P_src_dis = (P_src.unsqueeze(1) - P_src.unsqueeze(2))
        P_src_dis = torch.norm(P_src_dis, p=2, dim=3).detach()
        P_tgt_dis = (P_tgt.unsqueeze(1) - P_tgt.unsqueeze(2))
        P_tgt_dis = torch.norm(P_tgt_dis, p=2, dim=3).detach()

        Q_src = torch.exp(-P_src_dis / self.rescale[0])
        Q_tgt = torch.exp(-P_tgt_dis / self.rescale[0])

        emb_edge1 = Q_src.unsqueeze(-1)
        emb_edge2 = Q_tgt.unsqueeze(-1)

        # adjacency matrices
        A_src = torch.bmm(G_src, H_src.transpose(1, 2))
        A_tgt = torch.bmm(G_tgt, H_tgt.transpose(1, 2))

        # U_src, F_src are features at different scales
        emb1, emb2 = torch.cat((U_src, F_src), dim=1).transpose(1, 2), torch.cat((U_tgt, F_tgt), dim=1).transpose(1, 2)
        ss = []

        for i in range(self.gnn_layer):
            gnn_layer = getattr(self, 'gnn_layer_{}'.format(i))

            # during forward process, the network structure will not change
            emb1, emb2, emb_edge1, emb_edge2 = gnn_layer([A_src, emb1, emb_edge1], [A_tgt, emb2, emb_edge2])

            affinity = getattr(self, 'affinity_{}'.format(i))
            s = affinity(emb1, emb2) # xAx^T

            s = self.sinkhorn(s, ns_src, ns_tgt)
            ss.append(s)

            if i == self.gnn_layer - 2:
                cross_graph = getattr(self, 'cross_graph_{}'.format(i))
                new_emb1 = cross_graph(torch.cat((emb1, torch.bmm(s, emb2)), dim=-1))
                new_emb2 = cross_graph(torch.cat((emb2, torch.bmm(s.transpose(1, 2), emb1)), dim=-1))
                emb1 = new_emb1
                emb2 = new_emb2
                # Edge embeddings are not updated by the cross-graph step.

        data_dict.update({
            'ds_mat': ss[-1],
            'perm_mat': hungarian(ss[-1], ns_src, ns_tgt)
        })
        return data_dict
Beispiel #4
0
    def forward(
        self,
        data_dict,
    ):
        """Match graph pairs with a blackbox combinatorial solver (BBGM).

        Node/edge features are extracted per graph and turned into unary and
        quadratic *costs* (negated affinities; quadratic ones scaled by 0.5).
        In training mode the ground-truth assignment is added to the unary
        costs as a margin. The costs are then solved by
        ``GraphMatchingModule`` per pair (``"LPMP"``) or jointly by
        ``MultiGraphMatchingModule`` (``"LPMP_MGM"``).

        Args:
            data_dict: batch dictionary. Keys read here: ``'images'``,
                ``'Ps'``, ``'ns'``, ``'pyg_graphs'`` and, when available,
                ``'gt_perm_mat'``.

        Returns:
            ``data_dict`` updated in place. For ``2GM``:
            ``'ds_mat'`` (``None``) and ``'perm_mat'``. For ``MGM``:
            ``'perm_mat_list'`` and ``'graph_indices'``.

        Raises:
            ValueError: in training mode when ground truth is missing or the
                problem type is neither ``'2GM'`` nor ``'MGM'``, or when
                ``cfg.BBGM.SOLVER_NAME`` is unknown.
        """
        images = data_dict['images']
        points = data_dict['Ps']
        n_points = data_dict['ns']
        graphs = data_dict['pyg_graphs']
        num_graphs = len(images)

        if cfg.PROBLEM.TYPE == '2GM' and 'gt_perm_mat' in data_dict:
            gt_perm_mats = [data_dict['gt_perm_mat']]
        elif cfg.PROBLEM.TYPE == 'MGM' and 'gt_perm_mat' in data_dict:
            gt_perm_mats = data_dict['gt_perm_mat'].values()
        elif self.training:
            raise ValueError(
                'Ground truth information is required during training.')
        else:
            # ground truth only feeds the training-time margin below, so
            # inference can proceed without it
            gt_perm_mats = None

        global_list = []
        orig_graph_list = []
        for image, p, n_p, graph in zip(images, points, n_points, graphs):
            # extract feature
            nodes = self.node_layers(image)
            edges = self.edge_layers(nodes)

            # the global descriptor is taken before channel normalization
            global_list.append(
                self.final_layers(edges).reshape((nodes.shape[0], -1)))
            nodes = normalize_over_channels(nodes)
            edges = normalize_over_channels(edges)

            # arrange features
            U = concat_features(feature_align(nodes, p, n_p, self.rescale),
                                n_p)
            F = concat_features(feature_align(edges, p, n_p, self.rescale),
                                n_p)
            node_features = torch.cat((U, F), dim=1)
            graph.x = node_features

            graph = self.message_pass_node_features(graph)
            orig_graph = self.build_edge_features_from_node_features(graph)
            orig_graph_list.append(orig_graph)

        global_weights_list = [
            torch.cat([global_src, global_tgt], axis=-1)
            for global_src, global_tgt in lexico_iter(global_list)
        ]

        global_weights_list = [
            normalize_over_channels(g) for g in global_weights_list
        ]

        unary_costs_list = [
            self.vertex_affinity([item.x for item in g_1],
                                 [item.x for item in g_2], global_weights)
            for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list),
                                                  global_weights_list)
        ]

        # Similarities to costs
        unary_costs_list = [[-x for x in unary_costs]
                            for unary_costs in unary_costs_list]

        if self.training:
            unary_costs_list = [
                [
                    x +
                    1.0 * gt[:dim_src, :dim_tgt]  # Add margin with alpha = 1.0
                    for x, gt, dim_src, dim_tgt in zip(
                        unary_costs, gt_perm_mat, ns_src, ns_tgt)
                ] for unary_costs, gt_perm_mat, (ns_src, ns_tgt) in zip(
                    unary_costs_list, gt_perm_mats, lexico_iter(n_points))
            ]

        quadratic_costs_list = [
            self.edge_affinity([item.edge_attr for item in g_1],
                               [item.edge_attr
                                for item in g_2], global_weights)
            for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list),
                                                  global_weights_list)
        ]

        # Similarities to costs
        quadratic_costs_list = [[-0.5 * x for x in quadratic_costs]
                                for quadratic_costs in quadratic_costs_list]

        # edge indices per graph and batch element, needed by both solvers
        all_edges = [[item.edge_index for item in graph]
                     for graph in orig_graph_list]
        if cfg.BBGM.SOLVER_NAME == "LPMP":
            gm_solvers = [
                GraphMatchingModule(
                    all_left_edges,
                    all_right_edges,
                    ns_src,
                    ns_tgt,
                    cfg.BBGM.LAMBDA_VAL,
                    cfg.BBGM.SOLVER_PARAMS,
                )
                for (all_left_edges, all_right_edges), (ns_src, ns_tgt) in zip(
                    lexico_iter(all_edges), lexico_iter(n_points))
            ]
            matchings = [
                gm_solver(unary_costs, quadratic_costs)
                for gm_solver, unary_costs, quadratic_costs in zip(
                    gm_solvers, unary_costs_list, quadratic_costs_list)
            ]
        elif cfg.BBGM.SOLVER_NAME == "LPMP_MGM":
            gm_solver = MultiGraphMatchingModule(all_edges, n_points,
                                                 cfg.BBGM.LAMBDA_VAL,
                                                 cfg.BBGM.SOLVER_PARAMS)
            matchings = gm_solver(unary_costs_list, quadratic_costs_list)
        else:
            raise ValueError("Unknown solver {}".format(cfg.BBGM.SOLVER_NAME))

        if cfg.PROBLEM.TYPE == '2GM':
            data_dict.update({'ds_mat': None, 'perm_mat': matchings[0]})
        elif cfg.PROBLEM.TYPE == 'MGM':
            indices = list(lexico_iter(range(num_graphs)))
            data_dict.update({
                'perm_mat_list': matchings,
                'graph_indices': indices,
            })

        return data_dict
Beispiel #5
0
    def forward(self, data_dict, **kwargs):
        """Match one source/target graph pair via affinity matrix + solver.

        Node (``U``) and edge (``F``) features are obtained either from the
        backbone sampled at the keypoints or by splitting precomputed
        features in half along dim 1. ``self.affinity_layer`` turns them into
        edge/node affinities, which ``construct_aff_mat`` assembles into the
        affinity matrix ``M``. ``self.gm_solver`` solves it, and the result
        is reshaped to a score matrix, Sinkhorn-normalized and discretized.

        Args:
            data_dict: batch dictionary with either ``'images'`` or
                ``'features'``, plus ``'Ps'``, ``'ns'``, ``'Gs'``, ``'Hs'``
                and ``'KGHs'``.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            ``data_dict`` updated in place with ``'ds_mat'``, ``'perm_mat'``
            and ``'aff_mat'``.

        Raises:
            ValueError: if neither ``'images'`` nor ``'features'`` is present.
        """
        from_images = 'images' in data_dict
        if from_images:
            src, tgt = data_dict['images']
        elif 'features' in data_dict:
            src, tgt = data_dict['features']
        else:
            raise ValueError('Unknown data type for this model.')

        # graph structure, identical for both input modes
        P_src, P_tgt = data_dict['Ps']
        ns_src, ns_tgt = data_dict['ns']
        G_src, G_tgt = data_dict['Gs']
        H_src, H_tgt = data_dict['Hs']
        K_G, K_H = data_dict['KGHs']

        if from_images:
            # edge features are computed from the *raw* node features;
            # both are L2-normalized only afterwards
            src_node = self.node_layers(src)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt)
            tgt_edge = self.edge_layers(tgt_node)
            src_node, src_edge, tgt_node, tgt_edge = (
                self.l2norm(fmap)
                for fmap in (src_node, src_edge, tgt_node, tgt_edge))

            # sample the feature maps at the keypoints
            U_src, F_src = (
                feature_align(fmap, P_src, ns_src, self.rescale)
                for fmap in (src_node, src_edge))
            U_tgt, F_tgt = (
                feature_align(fmap, P_tgt, ns_tgt, self.rescale)
                for fmap in (tgt_node, tgt_edge))
        else:
            # precomputed features: node part first, edge part second
            half_src = src.shape[1] // 2
            half_tgt = tgt.shape[1] // 2
            U_src, F_src = src[:, :half_src, :], src[:, half_src:, :]
            U_tgt, F_tgt = tgt[:, :half_tgt, :], tgt[:, half_tgt:, :]

        edge_src = reshape_edge_feature(F_src, G_src, H_src)
        edge_tgt = reshape_edge_feature(F_tgt, G_tgt, H_tgt)

        # affinity layer
        Me, Mp = self.affinity_layer(edge_src, edge_tgt, U_src, U_tgt)
        M = construct_aff_mat(Me, Mp, K_G, K_H)

        solved = self.gm_solver(M, num_src=P_src.shape[1], ns_src=ns_src, ns_tgt=ns_tgt)
        scores = solved.view(solved.shape[0], P_tgt.shape[1], -1).transpose(1, 2)
        ds_mat = self.sinkhorn(scores, ns_src, ns_tgt)

        data_dict.update({
            'ds_mat': ds_mat,
            'perm_mat': hungarian(ds_mat, ns_src, ns_tgt),
            'aff_mat': M
        })
        return data_dict
Beispiel #6
0
    def forward(self, data_dict):
        """Match a graph pair with a second- plus third-order NGM.

        Builds a second-order affinity matrix ``K`` from edge affinities
        (the node term passed to ``construct_aff_mat`` is zeroed). It also
        builds a third-order affinity tensor ``H`` over triples of
        assignments that are pairwise adjacent in ``K > 0``. The
        ``gnn_layer_{i}`` stack then runs on ``(K, H)``, and
        ``self.classifier`` scores each assignment.

        Args:
            data_dict: batch dictionary with either ``'images'`` or
                ``'features'``, plus ``'Ps'``, ``'ns'``, ``'Gs'``, ``'Hs'``
                and ``'KGHs'``.

        Returns:
            ``data_dict`` updated in place with ``'ds_mat'``
            (``self.bi_stochastic`` output), ``'perm_mat'`` (Hungarian
            discretization) and ``'aff_mat'`` (``K``).

        Raises:
            ValueError: on an unknown input type or an unknown
                ``cfg.NGM.EDGE_FEATURE`` / ``cfg.NGM.ORDER3_FEATURE``.
        """
        if 'images' in data_dict:
            # real image data
            src, tgt = data_dict['images']
            P_src, P_tgt = data_dict['Ps']
            ns_src, ns_tgt = data_dict['ns']
            G_src, G_tgt = data_dict['Gs']
            H_src, H_tgt = data_dict['Hs']
            K_G, K_H = data_dict['KGHs']

            # extract feature
            src_node = self.node_layers(src)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt)
            tgt_edge = self.edge_layers(tgt_node)

            # feature normalization
            src_node = self.l2norm(src_node)
            src_edge = self.l2norm(src_edge)
            tgt_node = self.l2norm(tgt_node)
            tgt_edge = self.l2norm(tgt_edge)

            # arrange features
            U_src = feature_align(src_node, P_src, ns_src, self.rescale)
            F_src = feature_align(src_edge, P_src, ns_src, self.rescale)
            U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, self.rescale)
            F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, self.rescale)
        elif 'features' in data_dict:
            # synthetic data
            src, tgt = data_dict['features']
            P_src, P_tgt = data_dict['Ps']
            ns_src, ns_tgt = data_dict['ns']
            G_src, G_tgt = data_dict['Gs']
            H_src, H_tgt = data_dict['Hs']
            K_G, K_H = data_dict['KGHs']

            U_src = src[:, :src.shape[1] // 2, :]
            F_src = src[:, src.shape[1] // 2:, :]
            U_tgt = tgt[:, :tgt.shape[1] // 2, :]
            F_tgt = tgt[:, tgt.shape[1] // 2:, :]
        else:
            raise ValueError('Unknown data type for this model.')

        X = reshape_edge_feature(F_src, G_src, H_src)
        Y = reshape_edge_feature(F_tgt, G_tgt, H_tgt)
        dx = geo_edge_feature(P_src, G_src, H_src)[:, :1, :]
        dy = geo_edge_feature(P_tgt, G_tgt, H_tgt)[:, :1, :]

        # affinity layer for 2-order affinity matrix
        if cfg.NGM.EDGE_FEATURE == 'cat':
            Ke, Kp = self.feat_affinity_layer(X, Y, U_src, U_tgt)
        elif cfg.NGM.EDGE_FEATURE == 'geo':
            Ke, Kp = self.geo_affinity_layer(dx, dy, U_src, U_tgt)
        else:
            raise ValueError('Unknown edge feature type {}'.format(
                cfg.NGM.EDGE_FEATURE))

        K = construct_aff_mat(Ke, torch.zeros_like(Kp), K_G, K_H)
        # CPU round-trips below return to K's device, so the model also runs
        # on CPU or a non-default GPU
        device = K.device
        adj = (K > 0).to(K.dtype)

        # build 3-order affinity tensor: order3A[b, i, j, k] is non-zero only
        # if adj[b, i, j], adj[b, i, k] and adj[b, j, k] all are
        hshape = list(adj.shape) + [adj.shape[-1]]
        order3A = adj.unsqueeze(1).expand(hshape) * adj.unsqueeze(2).expand(
            hshape) * adj.unsqueeze(3).expand(hshape)
        hyper_adj = order3A

        if cfg.NGM.ORDER3_FEATURE == 'cat':
            Ke_3, _ = self.feat_affinity_layer3(X,
                                                Y,
                                                torch.zeros(1, 1, 1),
                                                torch.zeros(1, 1, 1),
                                                w1=0.5,
                                                w2=1)
            K_3 = construct_aff_mat(Ke_3, torch.zeros_like(Kp), K_G, K_H)
            H = (K_3.unsqueeze(1).expand(hshape) +
                 K_3.unsqueeze(2).expand(hshape) +
                 K_3.unsqueeze(3).expand(hshape)) * F.relu(self.weight3)
        elif cfg.NGM.ORDER3_FEATURE == 'geo':
            Ke_d, _ = self.geo_affinity_layer(dx, dy, torch.zeros(1, 1, 1),
                                              torch.zeros(1, 1, 1))

            # squeeze only the singleton feature dim: a bare squeeze() would
            # also drop the batch or edge dim when either has size 1
            m_d_src = construct_aff_mat(
                dx.squeeze(1).unsqueeze(-1).expand_as(Ke_d),
                torch.zeros_like(Kp), K_G, K_H).cpu()
            m_d_tgt = construct_aff_mat(
                dy.squeeze(1).unsqueeze(-2).expand_as(Ke_d),
                torch.zeros_like(Kp), K_G, K_H).cpu()
            # the angle computation runs on CPU (presumably to limit device
            # memory for these N^3 tensors -- TODO confirm)
            order3A_cpu = order3A.cpu()

            def calc_sin(t, i):
                """Masked sine of the angle at vertex ``i % 3`` of each triple.

                Law of cosines on the three side lengths of ``t``, broadcast
                to ``hshape``. The cosine is clamped to [-1, 1] so the sqrt
                stays real.
                """
                a = t.unsqueeze(i % 3 + 1).expand(hshape)
                b = t.unsqueeze((i + 1) % 3 + 1).expand(hshape)
                c = t.unsqueeze((i + 2) % 3 + 1).expand(hshape)
                cos = torch.clamp(
                    (a.pow(2) + b.pow(2) - c.pow(2)) / (2 * a * b + 1e-15),
                    -1, 1)
                cos *= order3A_cpu
                sin = torch.sqrt(1 - cos.pow(2)) * order3A_cpu
                assert torch.sum(torch.isnan(sin)) == 0
                return sin

            cum_sin = torch.zeros_like(order3A_cpu)
            for i in range(3):
                sin_src = calc_sin(m_d_src, i)
                sin_tgt = calc_sin(m_d_tgt, i)
                cum_sin += torch.abs(sin_src - sin_tgt)

            H = torch.exp(-1 / cfg.NGM.SIGMA3 * cum_sin) * order3A_cpu
            H = H.to(device)
        elif cfg.NGM.ORDER3_FEATURE == 'none':
            H = torch.zeros_like(hyper_adj)
        else:
            raise ValueError('Unknown edge feature type {}'.format(
                cfg.NGM.ORDER3_FEATURE))

        # normalize the hyper-adjacency over its last two dims, then sparsify
        hyper_adj = hyper_adj.cpu()
        hyper_adj_sum = torch.sum(
            hyper_adj, dim=tuple(range(2, 3 + 1)), keepdim=True) + 1e-10
        hyper_adj = hyper_adj / hyper_adj_sum
        hyper_adj = hyper_adj.to_sparse().coalesce().to(device)

        # keep H only where the hyper-adjacency is non-zero
        H = H.sparse_mask(hyper_adj)
        H = (H._indices(), H._values().unsqueeze(-1))

        if cfg.NGM.FIRST_ORDER:
            emb = Kp.transpose(1, 2).contiguous().view(Kp.shape[0], -1, 1)
        else:
            emb = torch.ones(K.shape[0], K.shape[1], 1, device=K.device)

        # row-normalized second-order adjacency
        adj_sum = torch.sum(adj, dim=2, keepdim=True) + 1e-10
        adj = adj / adj_sum
        pack_M = [K.unsqueeze(-1), H]
        pack_A = [adj, hyper_adj]
        for i in range(self.gnn_layer):
            gnn_layer = getattr(self, 'gnn_layer_{}'.format(i))
            pack_M, emb = gnn_layer(pack_A,
                                    pack_M,
                                    emb,
                                    ns_src,
                                    ns_tgt,
                                    norm=False)

        v = self.classifier(emb)
        s = v.view(v.shape[0], P_tgt.shape[1], -1).transpose(1, 2)

        ss = self.bi_stochastic(s, ns_src, ns_tgt)
        x = hungarian(ss, ns_src, ns_tgt)

        data_dict.update({'ds_mat': ss, 'perm_mat': x, 'aff_mat': K})

        return data_dict
Beispiel #7
0
    def forward(self, data_dict, **kwargs):
        """Solve a batch of graph-matching problems with the NGM QAP solver.

        The input mode is selected by which key is present in ``data_dict``:

        * ``'images'``: real images. Node/edge feature maps are extracted by
          the backbone, L2-normalised and sampled at the keypoints ``'Ps'``.
        * ``'features'``: synthetic data. The first half of the feature
          channels are treated as node features, the second half as edge
          features.
        * ``'aff_mat'``: a precomputed affinity matrix ``K``.

        Args:
            data_dict: batch dictionary. ``'batch_size'`` and ``'ns'`` are
                always read; ``'Ps'``, ``'Gs'``, ``'Hs'`` and ``'KGHs'`` are
                read in the image/feature modes.
            **kwargs: unused.

        Returns:
            ``data_dict``, updated in place with ``'ds_mat'`` (the Sinkhorn
            output, or the selected Gumbel sample at evaluation time),
            ``'perm_mat'`` (the Hungarian-discretised matching) and
            ``'aff_mat'`` (the affinity matrix ``K`` actually used).

        Raises:
            ValueError: if none of the supported input keys is present, or
                ``cfg.NGM.EDGE_FEATURE`` is neither ``'cat'`` nor ``'geo'``.
        """
        batch_size = data_dict['batch_size']
        if 'images' in data_dict:
            # real image data
            src, tgt = data_dict['images']
            P_src, P_tgt = data_dict['Ps']
            ns_src, ns_tgt = data_dict['ns']
            G_src, G_tgt = data_dict['Gs']
            H_src, H_tgt = data_dict['Hs']
            K_G, K_H = data_dict['KGHs']

            # extract features; edge features are computed from the
            # (not yet normalised) node feature maps
            src_node = self.node_layers(src)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt)
            tgt_edge = self.edge_layers(tgt_node)

            # feature normalization
            src_node = self.l2norm(src_node)
            src_edge = self.l2norm(src_edge)
            tgt_node = self.l2norm(tgt_node)
            tgt_edge = self.l2norm(tgt_edge)

            # sample the feature maps at the keypoint locations
            U_src = feature_align(src_node, P_src, ns_src, self.rescale)
            F_src = feature_align(src_edge, P_src, ns_src, self.rescale)
            U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, self.rescale)
            F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, self.rescale)
        elif 'features' in data_dict:
            # synthetic data: first half of dim 1 = node features,
            # second half = edge features
            src, tgt = data_dict['features']
            P_src, P_tgt = data_dict['Ps']
            ns_src, ns_tgt = data_dict['ns']
            G_src, G_tgt = data_dict['Gs']
            H_src, H_tgt = data_dict['Hs']
            K_G, K_H = data_dict['KGHs']

            U_src = src[:, :src.shape[1] // 2, :]
            F_src = src[:, src.shape[1] // 2:, :]
            U_tgt = tgt[:, :tgt.shape[1] // 2, :]
            F_tgt = tgt[:, tgt.shape[1] // 2:, :]
        elif 'aff_mat' in data_dict:
            # precomputed affinity matrix
            K = data_dict['aff_mat']
            ns_src, ns_tgt = data_dict['ns']
        else:
            raise ValueError('Unknown data type for this model.')

        if 'images' in data_dict or 'features' in data_dict:
            tgt_len = P_tgt.shape[1]
            if cfg.NGM.EDGE_FEATURE == 'cat':
                X = reshape_edge_feature(F_src, G_src, H_src)
                Y = reshape_edge_feature(F_tgt, G_tgt, H_tgt)
            elif cfg.NGM.EDGE_FEATURE == 'geo':
                X = geo_edge_feature(P_src, G_src, H_src)[:, :1, :]
                Y = geo_edge_feature(P_tgt, G_tgt, H_tgt)[:, :1, :]
            else:
                raise ValueError('Unknown edge feature type {}'.format(
                    cfg.NGM.EDGE_FEATURE))

            # affinity layer
            Ke, Kp = self.affinity_layer(X, Y, U_src, U_tgt)

            # Only the edge affinity Ke enters K; the node-affinity slot is
            # zero-filled. Kp is used below only as the initial vertex
            # embedding when cfg.NGM.FIRST_ORDER is set.
            K = construct_aff_mat(Ke, torch.zeros_like(Kp), K_G, K_H)

            # binary adjacency of the association graph (1 where K > 0)
            A = (K > 0).to(K.dtype)

            if cfg.NGM.FIRST_ORDER:
                emb = Kp.transpose(1, 2).contiguous().view(Kp.shape[0], -1, 1)
            else:
                emb = torch.ones(K.shape[0], K.shape[1], 1, device=K.device)
        else:
            # NOTE(review): assumes K.shape[2] == tgt_len ** 2, i.e. both
            # graphs padded to the same size - confirm with the data loader.
            tgt_len = int(math.sqrt(K.shape[2]))
            # rescale K so that its largest row sum (per batch item) is ~1000
            dmax = (torch.max(torch.sum(K, dim=2, keepdim=True),
                              dim=1,
                              keepdim=True).values + 1e-5)
            K = K / dmax * 1000
            A = (K > 0).to(K.dtype)
            emb = torch.ones(K.shape[0], K.shape[1], 1, device=K.device)

        emb_K = K.unsqueeze(-1)

        # NGM qap solver
        for i in range(self.gnn_layer):
            gnn_layer = getattr(self, 'gnn_layer_{}'.format(i))
            emb_K, emb = gnn_layer(A, emb_K, emb, ns_src, ns_tgt)

        # per-vertex scores of the association graph -> score matrix whose
        # last dimension has size tgt_len
        v = self.classifier(emb)
        s = v.view(v.shape[0], tgt_len, -1).transpose(1, 2)

        if self.training or cfg.NGM.GUMBEL_SK <= 0:
            ss = self.sinkhorn(s, ns_src, ns_tgt, dummy_row=True)
            x = hungarian(ss, ns_src, ns_tgt)
        else:
            # Evaluation with Gumbel-Sinkhorn (this branch is only reached when
            # not training): draw GUMBEL_SK samples per instance, discretise
            # each one, and keep the sample with the lowest objective score.
            gumbel_sample_num = cfg.NGM.GUMBEL_SK

            def repeat(t, rep_num=gumbel_sample_num):
                # tile every batch element rep_num times along dim 0
                return torch.repeat_interleave(t, rep_num, dim=0)

            ss_gumbel = self.gumbel_sinkhorn(s,
                                             ns_src,
                                             ns_tgt,
                                             sample_num=gumbel_sample_num,
                                             dummy_row=True)
            ss_gumbel = hungarian(ss_gumbel, repeat(ns_src), repeat(ns_tgt))
            ss_gumbel = ss_gumbel.reshape(batch_size, gumbel_sample_num,
                                          ss_gumbel.shape[-2],
                                          ss_gumbel.shape[-1])

            # Score the samples in chunks so that the repeated copies of K
            # fit into the free GPU memory.
            if ss_gumbel.device.type == 'cuda':
                dev_idx = ss_gumbel.device.index
                # keep 100MB as buffer for other computations
                free_mem = gpu_free_memory(dev_idx) - 100 * 1024**2
                K_mem_size = K.element_size() * K.nelement()
                max_repeats = free_mem // K_mem_size
                if max_repeats <= 0:
                    print('Warning: GPU may not have enough memory')
                    max_repeats = 1
            else:
                max_repeats = gumbel_sample_num

            obj_score = []
            for idx in range(0, gumbel_sample_num, max_repeats):
                rep_num = min(max_repeats, gumbel_sample_num - idx)
                obj_score.append(
                    objective_score(
                        ss_gumbel[:, idx:(idx + rep_num), :, :].reshape(
                            -1, ss_gumbel.shape[-2], ss_gumbel.shape[-1]),
                        repeat(K, rep_num)).reshape(batch_size, -1))
            obj_score = torch.cat(obj_score, dim=1)
            best_sample = obj_score.min(dim=1).indices.cpu()
            ss = ss_gumbel[torch.arange(batch_size), best_sample, :, :]
            # ``ss`` now holds exactly one sample per instance (batch dim ==
            # batch_size), so it must be paired with the un-repeated sizes;
            # repeated sizes would give instance b the size of instance
            # b // gumbel_sample_num.
            x = hungarian(ss, ns_src, ns_tgt)

        data_dict.update({'ds_mat': ss, 'perm_mat': x, 'aff_mat': K})
        return data_dict
Beispiel #8
0
    def forward(self, data_dict, **kwargs):
        """Match two graphs with PCA-GM, or with IPCA-GM when ``self.cross_iter`` is set.

        Node embeddings are built either from real images (``'images'`` key:
        backbone features sampled at keypoints ``'Ps'``) or from synthetic
        features (``'features'`` key: first half of dim 1 is node features,
        second half edge features). They are then refined by the GNN layers.
        Each refinement yields a Sinkhorn-normalised score matrix.

        Args:
            data_dict: batch dictionary; reads ``'ns'`` and ``'As'`` plus
                either ``'images'``/``'Ps'`` or ``'features'``.
            **kwargs: unused.

        Returns:
            ``data_dict`` updated in place with ``'ds_mat'`` (the last
            Sinkhorn output) and ``'perm_mat'`` (its Hungarian discretisation).

        Raises:
            ValueError: if neither ``'images'`` nor ``'features'`` is present.
        """
        if 'images' in data_dict:
            # real image data
            src_img, tgt_img = data_dict['images']
            P_src, P_tgt = data_dict['Ps']
            ns_src, ns_tgt = data_dict['ns']
            A_src, A_tgt = data_dict['As']

            # backbone: edge maps are derived from the raw node maps
            src_node = self.node_layers(src_img)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt_img)
            tgt_edge = self.edge_layers(tgt_node)

            # L2-normalise all four maps (same call order as before)
            src_node, src_edge, tgt_node, tgt_edge = (
                self.l2norm(fmap)
                for fmap in (src_node, src_edge, tgt_node, tgt_edge))

            # sample the normalised maps at the keypoints
            U_src = feature_align(src_node, P_src, ns_src, self.rescale)
            F_src = feature_align(src_edge, P_src, ns_src, self.rescale)
            U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, self.rescale)
            F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, self.rescale)
        elif 'features' in data_dict:
            # synthetic data: split dim 1 into node / edge halves
            src_feat, tgt_feat = data_dict['features']
            ns_src, ns_tgt = data_dict['ns']
            A_src, A_tgt = data_dict['As']

            src_half = src_feat.shape[1] // 2
            tgt_half = tgt_feat.shape[1] // 2
            U_src, F_src = src_feat[:, :src_half, :], src_feat[:, src_half:, :]
            U_tgt, F_tgt = tgt_feat[:, :tgt_half, :], tgt_feat[:, tgt_half:, :]
        else:
            raise ValueError('Unknown data type for this model.')

        # concatenate node and edge features, then move the feature axis last
        emb1 = torch.cat((U_src, F_src), dim=1).transpose(1, 2)
        emb2 = torch.cat((U_tgt, F_tgt), dim=1).transpose(1, 2)
        score_history = []

        if self.cross_iter:
            # IPCA-GM: run every GNN layer except the last one once ...
            for layer_idx in range(self.gnn_layer - 1):
                layer = getattr(self, 'gnn_layer_{}'.format(layer_idx))
                emb1, emb2 = layer([A_src, emb1], [A_tgt, emb2])

            # ... then iterate cross-graph mixing + last GNN layer, each round
            # conditioned on the previous round's score matrix (zeros first).
            base1, base2 = emb1, emb2
            s = torch.zeros(base1.shape[0],
                            base1.shape[1],
                            base2.shape[1],
                            device=base1.device)

            for _ in range(self.cross_iter_num):
                cross_idx = self.gnn_layer - 2
                cross_graph = getattr(self, 'cross_graph_{}'.format(cross_idx))
                emb1 = cross_graph(
                    torch.cat((base1, torch.bmm(s, base2)), dim=-1))
                emb2 = cross_graph(
                    torch.cat((base2, torch.bmm(s.transpose(1, 2), base1)),
                              dim=-1))

                last_idx = self.gnn_layer - 1
                layer = getattr(self, 'gnn_layer_{}'.format(last_idx))
                emb1, emb2 = layer([A_src, emb1], [A_tgt, emb2])
                affinity = getattr(self, 'affinity_{}'.format(last_idx))
                s = self.sinkhorn(affinity(emb1, emb2), ns_src, ns_tgt,
                                  dummy_row=True)
                score_history.append(s)
        else:
            # Vanilla PCA-GM: score after every GNN layer; after the
            # second-to-last layer, mix in the other graph's embeddings.
            for layer_idx in range(self.gnn_layer):
                layer = getattr(self, 'gnn_layer_{}'.format(layer_idx))
                emb1, emb2 = layer([A_src, emb1], [A_tgt, emb2])
                affinity = getattr(self, 'affinity_{}'.format(layer_idx))
                s = self.sinkhorn(affinity(emb1, emb2), ns_src, ns_tgt,
                                  dummy_row=True)
                score_history.append(s)

                if layer_idx == self.gnn_layer - 2:
                    cross_graph = getattr(self, 'cross_graph_{}'.format(layer_idx))
                    # both right-hand sides use the pre-update embeddings
                    emb1, emb2 = (
                        cross_graph(
                            torch.cat((emb1, torch.bmm(s, emb2)), dim=-1)),
                        cross_graph(
                            torch.cat((emb2, torch.bmm(s.transpose(1, 2), emb1)),
                                      dim=-1)),
                    )

        final_scores = score_history[-1]
        data_dict.update({
            'ds_mat': final_scores,
            'perm_mat': hungarian(final_scores, ns_src, ns_tgt),
        })
        return data_dict