def forward(
            self,
            images,
            points,
            graphs,
            n_points,
            perm_mats,
            visualize_flag=False,
            visualization_params=None,
    ):
        """Build one feature graph and one global descriptor per image.

        For every ``(image, keypoints, keypoint count, graph)`` taken in
        lockstep from the four input sequences:

        1. ``self.node_layers`` / ``self.edge_layers`` produce node and edge
           feature maps; ``self.final_layers`` applied to the edge map gives
           a global descriptor (its first output, reshaped to
           ``(node_map.shape[0], -1)``).
        2. Both maps are normalized with ``normalize_over_channels``, sampled
           at the keypoints with ``feature_align`` (fixed ``(256, 256)``
           size), concatenated and reduced by ``self.reduce_feat_layer``;
           the result is stored in ``graph.x``.
        3. The graph goes through ``self.message_pass_node_features`` and
           ``self.build_edge_features_from_node_features``.

        ``perm_mats``, ``visualize_flag`` and ``visualization_params`` are
        accepted for interface compatibility but not read here.

        Returns:
            ``(orig_graph_list, global_list)``: processed graphs and global
            descriptors, one each per input image and in input order.
        """
        graphs_out = []
        descriptors = []
        samples = zip(images, points, n_points, graphs)
        for image, keypoints, num_keypoints, graph in samples:
            node_map = self.node_layers(image)
            edge_map = self.edge_layers(node_map)

            # Global descriptor, flattened per leading index of node_map
            # (presumably the batch dimension -- confirm against backbone).
            descriptor = self.final_layers(edge_map)[0]
            descriptors.append(descriptor.reshape((node_map.shape[0], -1)))

            node_map = normalize_over_channels(node_map)
            edge_map = normalize_over_channels(edge_map)

            node_part = concat_features(
                feature_align(node_map, keypoints, num_keypoints, (256, 256)),
                num_keypoints,
            )
            edge_part = concat_features(
                feature_align(edge_map, keypoints, num_keypoints, (256, 256)),
                num_keypoints,
            )
            graph.x = self.reduce_feat_layer(
                torch.cat((node_part, edge_part), dim=-1)
            )

            graph = self.message_pass_node_features(graph)
            graphs_out.append(self.build_edge_features_from_node_features(graph))

        return graphs_out, descriptors
Esempio n. 2
0
    def forward(self, src, tgt, P_src, P_tgt, ns_src, ns_tgt):
        """Match ``src`` keypoints to ``tgt`` keypoints using learned graphs.

        Steps:

        * compute backbone node/edge features for both images;
        * L2-normalize them;
        * sample them at the keypoints with ``feature_align``;
        * learn an adjacency from the node features with
          ``self.graph_learning`` (one module for both images);
        * build the affinity matrix;
        * run power iteration;
        * apply voting and bi-stochastic normalization;
        * compute the displacement.

        Returns:
            ``(s, d)``: the normalized score matrix and the first of the two
            values returned by ``self.displacement_layer``.
        """
        # (A) Backbone: edge features are computed from the node features.
        src_node = self.node_layers(src)
        src_edge = self.edge_layers(src_node)
        tgt_node = self.node_layers(tgt)
        tgt_edge = self.edge_layers(tgt_node)

        # (B) L2-normalize all four feature maps (same order as extracted).
        src_node, src_edge, tgt_node, tgt_edge = [
            self.l2norm(feat)
            for feat in (src_node, src_edge, tgt_node, tgt_edge)
        ]

        # (C) Sample node (U) and edge (F) features at the keypoints.
        rescale = cfg.PAIR.RESCALE
        U_src = feature_align(src_node, P_src, ns_src, rescale)
        F_src = feature_align(src_edge, P_src, ns_src, rescale)
        U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, rescale)
        F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, rescale)

        # (D)/(E) Learned adjacency per graph; the same module (shared
        # weights) handles source and target.
        A_src = self.graph_learning(U_src)
        A_tgt = self.graph_learning(U_tgt)

        # (F) Affinity matrix from the learned adjacencies and features.
        M = self.affinity_layer(A_src, A_tgt, F_src, F_tgt, U_src, U_tgt)

        # (G) Power iteration; the vector is reshaped so that, after the
        # transpose, the last axis has P_tgt.shape[1] entries.
        v = self.power_iteration(M)
        n_tgt = P_tgt.shape[1]
        s = v.view(v.shape[0], n_tgt, -1).transpose(1, 2)

        # (H) Voting, then bi-stochastic normalization.
        s = self.bi_stochastic(
            self.voting_layer(s, ns_src, ns_tgt), ns_src, ns_tgt)

        # (I) Displacement; the second returned value is discarded.
        d, _ = self.displacement_layer(s, P_src, P_tgt)
        return s, d
Esempio n. 3
0
    def forward(self,
                src,
                tgt,
                P_src,
                P_tgt,
                G_src,
                G_tgt,
                H_src,
                H_tgt,
                ns_src,
                ns_tgt,
                K_G,
                K_H,
                type='img'):
        """Graph matching with an affinity matrix assembled by ``construct_m``.

        ``type`` selects where node (``U``) and edge (``F``) features come from:

        * ``'img'`` / ``'image'``: ``src``/``tgt`` pass through
          ``node_layers`` then ``edge_layers``, are L2-normalized and sampled
          at ``P_src``/``P_tgt`` with ``feature_align``.
        * ``'feat'`` / ``'feature'``: ``src``/``tgt`` already hold features;
          the first half of dimension 1 is ``U``, the second half ``F``.

        Next, ``reshape_edge_feature`` arranges the edge features with
        ``G``/``H``. The affinity layer returns two affinity terms, which
        ``construct_m`` combines with ``K_G``/``K_H`` into ``M``. Finally,
        power iteration, voting and bi-stochastic normalization yield the
        score matrix.

        Returns:
            ``(s, d)``: the score matrix and the first value returned by
            ``self.displacement_layer``.

        Raises:
            ValueError: if ``type`` is not one of the accepted strings.
        """
        if type in ('img', 'image'):
            # Backbone: node features first, edge features derived from them.
            src_node = self.node_layers(src)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt)
            tgt_edge = self.edge_layers(tgt_node)

            # L2-normalize the four maps in extraction order.
            src_node, src_edge, tgt_node, tgt_edge = [
                self.l2norm(fmap)
                for fmap in (src_node, src_edge, tgt_node, tgt_edge)
            ]

            # Sample the maps at the keypoint locations.
            rescale = cfg.PAIR.RESCALE
            U_src = feature_align(src_node, P_src, ns_src, rescale)
            F_src = feature_align(src_edge, P_src, ns_src, rescale)
            U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, rescale)
            F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, rescale)
        elif type in ('feat', 'feature'):
            # Pre-computed features: split dimension 1 into (U | F) halves.
            half_src = src.shape[1] // 2
            half_tgt = tgt.shape[1] // 2
            U_src, F_src = src[:, :half_src, :], src[:, half_src:, :]
            U_tgt, F_tgt = tgt[:, :half_tgt, :], tgt[:, half_tgt:, :]
        else:
            raise ValueError('unknown type string {}'.format(type))

        edge_feat_src = reshape_edge_feature(F_src, G_src, H_src)
        edge_feat_tgt = reshape_edge_feature(F_tgt, G_tgt, H_tgt)

        Me, Mp = self.affinity_layer(edge_feat_src, edge_feat_tgt, U_src, U_tgt)
        affinity_matrix = construct_m(Me, Mp, K_G, K_H)

        # Power iteration, reshaped so the last axis has P_tgt.shape[1] entries.
        v = self.power_iteration(affinity_matrix)
        s = v.view(v.shape[0], P_tgt.shape[1], -1).transpose(1, 2)

        s = self.bi_stochastic(
            self.voting_layer(s, ns_src, ns_tgt), ns_src, ns_tgt)

        d, _ = self.displacement_layer(s, P_src, P_tgt)
        return s, d
Esempio n. 4
0
    def forward(self,
                src,
                tgt,
                P_src,
                P_tgt,
                G_src,
                G_tgt,
                H_src,
                H_tgt,
                ns_src,
                ns_tgt,
                K_G,
                K_H,
                type='img'):
        """Graph matching with stacked GNN layers and a cross-graph update.

        ``type`` selects where node (``U``) and edge (``F``) features come from:

        * ``'img'`` / ``'image'``: ``src``/``tgt`` pass through
          ``node_layers`` then ``edge_layers``, are L2-normalized and sampled
          at ``P_src``/``P_tgt`` with ``feature_align``.
        * ``'feat'`` / ``'feature'``: the first half of dimension 1 of
          ``src``/``tgt`` is ``U``, the second half ``F``.

        The initial embeddings are ``U`` and ``F`` stacked on dimension 1 and
        transposed. Each of the ``self.gnn_layer`` iterations runs
        ``gnn_layer_i`` over the adjacencies ``A = G @ H^T``. It then scores
        the embeddings with ``affinity_i`` and applies voting and
        bi-stochastic normalization. Right after the second-to-last
        iteration, ``cross_graph_i`` mixes each embedding with the
        ``s``-weighted embedding of the other graph.

        ``K_G`` and ``K_H`` are accepted but not used.

        Returns:
            ``(s, d)``: the score matrix of the last iteration and the first
            value returned by ``self.displacement_layer``.

        Raises:
            ValueError: if ``type`` is not one of the accepted strings.
        """
        if type in ('img', 'image'):
            # Backbone: node features first, edge features derived from them.
            src_node = self.node_layers(src)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt)
            tgt_edge = self.edge_layers(tgt_node)

            # L2-normalize the four maps in extraction order.
            src_node, src_edge, tgt_node, tgt_edge = [
                self.l2norm(fmap)
                for fmap in (src_node, src_edge, tgt_node, tgt_edge)
            ]

            # Sample the maps at the keypoint locations.
            rescale = cfg.PAIR.RESCALE
            U_src = feature_align(src_node, P_src, ns_src, rescale)
            F_src = feature_align(src_edge, P_src, ns_src, rescale)
            U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, rescale)
            F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, rescale)
        elif type in ('feat', 'feature'):
            # Pre-computed features: split dimension 1 into (U | F) halves.
            half_src = src.shape[1] // 2
            half_tgt = tgt.shape[1] // 2
            U_src, F_src = src[:, :half_src, :], src[:, half_src:, :]
            U_tgt, F_tgt = tgt[:, :half_tgt, :], tgt[:, half_tgt:, :]
        else:
            raise ValueError('unknown type string {}'.format(type))

        # Dense adjacency matrices from the factors: A = G @ H^T.
        A_src = torch.bmm(G_src, H_src.transpose(1, 2))
        A_tgt = torch.bmm(G_tgt, H_tgt.transpose(1, 2))

        # Initial embeddings: U and F stacked on dim 1, then transposed.
        emb1 = torch.cat((U_src, F_src), dim=1).transpose(1, 2)
        emb2 = torch.cat((U_tgt, F_tgt), dim=1).transpose(1, 2)

        for i in range(self.gnn_layer):
            gnn = getattr(self, 'gnn_layer_{}'.format(i))
            emb1, emb2 = gnn([A_src, emb1], [A_tgt, emb2])
            score_fn = getattr(self, 'affinity_{}'.format(i))
            s = self.bi_stochastic(
                self.voting_layer(score_fn(emb1, emb2), ns_src, ns_tgt),
                ns_src, ns_tgt)

            if i != self.gnn_layer - 2:
                continue

            # Cross-graph update: each embedding is concatenated with the
            # soft-matched embedding from the other graph. Both new values
            # are computed before either embedding is replaced.
            cross_graph = getattr(self, 'cross_graph_{}'.format(i))
            matched_for_src = torch.bmm(s, emb2)
            new_emb1 = cross_graph(torch.cat((emb1, matched_for_src), dim=-1))
            matched_for_tgt = torch.bmm(s.transpose(1, 2), emb1)
            new_emb2 = cross_graph(torch.cat((emb2, matched_for_tgt), dim=-1))
            emb1, emb2 = new_emb1, new_emb2

        d, _ = self.displacement_layer(s, P_src, P_tgt)
        return s, d
Esempio n. 5
0
    def forward(
        self,
        images,
        points,
        graphs,
        n_points,
        perm_mats,
        visualize_flag=False,
        visualization_params=None,
    ):
        """Predict keypoint matchings between image pairs (BB-GM style).

        Per image, the backbone node/edge maps are normalized, sampled at the
        keypoints and stored in ``graph.x``. The graph is then refined by
        message passing and given edge features. A global descriptor per
        image comes from ``self.final_layers``.

        For every pair yielded by ``lexico_iter``:

        * The two global descriptors are concatenated and normalized into
          ``global_weights``.
        * Unary costs are the negated ``self.vertex_affinity`` outputs.
        * Quadratic costs are ``self.edge_affinity`` outputs scaled by ``-0.5``.
        * The matching problem is solved by the back-end named in
          ``cfg.BB_GM.solver_name``.

        Args:
            images, points, graphs, n_points: per-image sequences, consumed
                in lockstep.
            perm_mats: ground-truth matrices, one per pair; only read in
                training mode to add a margin to the unary costs.
            visualize_flag: if true, pass intermediate results to
                ``easy_visualize``.
            visualization_params: extra keyword arguments for
                ``easy_visualize``; ``None`` (the default) means none.

        Returns:
            The solver output: a list with one entry per pair for ``"lpmp"``,
            or whatever ``MultiGraphMatchingModule`` returns for
            ``"multigraph"``.

        Raises:
            ValueError: if ``cfg.BB_GM.solver_name`` is neither ``"lpmp"``
                nor ``"multigraph"``.
        """
        global_list = []
        orig_graph_list = []
        for image, p, n_p, graph in zip(images, points, n_points, graphs):
            # extract feature
            nodes = self.node_layers(image)
            edges = self.edge_layers(nodes)

            # Global descriptor, flattened to (nodes.shape[0], -1).
            global_list.append(
                self.final_layers(edges)[0].reshape((nodes.shape[0], -1)))
            nodes = normalize_over_channels(nodes)
            edges = normalize_over_channels(edges)

            # arrange features: sample node and edge maps at the keypoints
            U = concat_features(feature_align(nodes, p, n_p, (256, 256)), n_p)
            F = concat_features(feature_align(edges, p, n_p, (256, 256)), n_p)
            node_features = torch.cat((U, F), dim=-1)
            graph.x = node_features

            graph = self.message_pass_node_features(graph)
            orig_graph = self.build_edge_features_from_node_features(graph)
            orig_graph_list.append(orig_graph)

        # One normalized weight vector per pair: both descriptors concatenated.
        global_weights_list = [
            torch.cat([global_src, global_tgt], dim=-1)
            for global_src, global_tgt in lexico_iter(global_list)
        ]
        global_weights_list = [
            normalize_over_channels(g) for g in global_weights_list
        ]

        unary_costs_list = [
            self.vertex_affinity([item.x for item in g_1],
                                 [item.x for item in g_2], global_weights)
            for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list),
                                                  global_weights_list)
        ]

        # Similarities to costs
        unary_costs_list = [[-x for x in unary_costs]
                            for unary_costs in unary_costs_list]

        if self.training:
            # Margin (alpha = 1.0): add the ground-truth matrix, cropped to
            # each item's point counts, to the unary costs (presumably 1 at
            # ground-truth assignments, making them costlier during training).
            unary_costs_list = [
                [
                    x + 1.0 * gt[:dim_src, :dim_tgt]
                    for x, gt, dim_src, dim_tgt in zip(unary_costs, perm_mat,
                                                       ns_src, ns_tgt)
                ] for unary_costs, perm_mat, (ns_src, ns_tgt) in zip(
                    unary_costs_list, perm_mats, lexico_iter(n_points))
            ]

        quadratic_costs_list = [
            self.edge_affinity([item.edge_attr for item in g_1],
                               [item.edge_attr
                                for item in g_2], global_weights)
            for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list),
                                                  global_weights_list)
        ]

        # Similarities to costs (edge similarities are scaled by -0.5)
        quadratic_costs_list = [[-0.5 * x for x in quadratic_costs]
                                for quadratic_costs in quadratic_costs_list]

        # Edge indices of every graph; both solver back-ends need them.
        all_edges = [[item.edge_index for item in graph_batch]
                     for graph_batch in orig_graph_list]

        solver_name = cfg.BB_GM.solver_name
        if solver_name == "lpmp":
            gm_solvers = [
                GraphMatchingModule(
                    all_left_edges,
                    all_right_edges,
                    ns_src,
                    ns_tgt,
                    cfg.BB_GM.lambda_val,
                    cfg.BB_GM.solver_params,
                )
                for (all_left_edges, all_right_edges), (ns_src, ns_tgt) in zip(
                    lexico_iter(all_edges), lexico_iter(n_points))
            ]
            matchings = [
                gm_solver(unary_costs, quadratic_costs)
                for gm_solver, unary_costs, quadratic_costs in zip(
                    gm_solvers, unary_costs_list, quadratic_costs_list)
            ]
        elif solver_name == "multigraph":
            gm_solver = MultiGraphMatchingModule(all_edges, n_points,
                                                 cfg.BB_GM.lambda_val,
                                                 cfg.BB_GM.solver_params)
            matchings = gm_solver(unary_costs_list, quadratic_costs_list)
        else:
            raise ValueError(f"Unknown solver {solver_name}")

        if visualize_flag:
            # `**None` raises TypeError, so treat the None default as "no
            # extra keyword arguments".
            extra_params = ({} if visualization_params is None
                            else visualization_params)
            easy_visualize(
                orig_graph_list,
                points,
                n_points,
                images,
                unary_costs_list,
                quadratic_costs_list,
                matchings,
                **extra_params,
            )

        return matchings
Esempio n. 6
0
    def forward(self,
                src,
                tgt,
                P_src,
                P_tgt,
                G_src,
                G_tgt,
                H_src,
                H_tgt,
                ns_src,
                ns_tgt,
                K_G,
                K_H,
                edge_src,
                edge_tgt,
                edge_feat1,
                edge_feat2,
                perm_mat,
                type='img'):
        """Graph matching with a geometric prior and a quadratic refinement.

        ``type`` selects where node (``U``) and edge (``F``) features come from:

        * ``'img'`` / ``'image'``: ``src``/``tgt`` pass through
          ``node_layers`` then ``edge_layers``, are L2-normalized and sampled
          at ``P_src``/``P_tgt`` with ``feature_align``.
        * ``'feat'`` / ``'feature'``: the first half of dimension 1 of
          ``src``/``tgt`` is ``U``, the second half ``F``.

        Initial node embeddings stack ``U``, ``F`` and the keypoint
        coordinates scaled to unit length. Zero-length coordinate rows stay
        zero, and each point set is normalized independently, so source and
        target may have different numbers of points. Layer 0 propagates over
        ``A = G @ H^T``.

        After every layer, ``AA``/``BB`` hold ``exp(cosine similarity)``
        between embeddings of the same graph. Masked by ``A_src``/``A_tgt``,
        they form the adjacency of the next layer. That layer's input is each
        embedding concatenated with the ``s``-weighted embedding of the other
        graph.

        At ``i == 1`` the score matrix is refined by Frank-Wolfe-style
        updates (step ``2/(k+2)``) and ``s`` becomes ``s + 0.5 * X``. In eval
        mode each batch item's ``s`` is divided by its maximum. Every layer
        ends with ``sm_layer`` then ``sh_layer``.

        ``K_G``, ``K_H``, ``edge_src``, ``edge_tgt``, ``edge_feat1``,
        ``edge_feat2`` and ``perm_mat`` are accepted but not used.

        Returns:
            ``(s, U_src, F_src, U_tgt, F_tgt, AA, BB)``; ``AA``/``BB`` are the
            similarity kernels of the last layer.

        Raises:
            ValueError: if ``type`` is not one of the accepted strings.
        """
        if type == 'img' or type == 'image':
            # extract feature
            src_node = self.node_layers(src)
            src_edge = self.edge_layers(src_node)
            tgt_node = self.node_layers(tgt)
            tgt_edge = self.edge_layers(tgt_node)

            # feature normalization
            src_node = self.l2norm(src_node)
            src_edge = self.l2norm(src_edge)
            tgt_node = self.l2norm(tgt_node)
            tgt_edge = self.l2norm(tgt_edge)

            # arrange features
            U_src = feature_align(src_node, P_src, ns_src, cfg.PAIR.RESCALE)
            F_src = feature_align(src_edge, P_src, ns_src, cfg.PAIR.RESCALE)
            U_tgt = feature_align(tgt_node, P_tgt, ns_tgt, cfg.PAIR.RESCALE)
            F_tgt = feature_align(tgt_edge, P_tgt, ns_tgt, cfg.PAIR.RESCALE)
        elif type == 'feat' or type == 'feature':
            U_src = src[:, :src.shape[1] // 2, :]
            F_src = src[:, src.shape[1] // 2:, :]
            U_tgt = tgt[:, :tgt.shape[1] // 2, :]
            F_tgt = tgt[:, tgt.shape[1] // 2:, :]
        else:
            raise ValueError('unknown type string {}'.format(type))

        A_src = torch.bmm(G_src, H_src.transpose(1, 2))
        A_tgt = torch.bmm(G_tgt, H_tgt.transpose(1, 2))

        # Unit-length keypoint coordinates. Each set uses its OWN row norms
        # (the old loop tested the source norm before dividing target rows,
        # yielding NaN or wrong zeros). Zero rows divide by 1 and stay zero.
        norm_src = torch.norm(P_src, dim=-1, keepdim=True)
        norm_tgt = torch.norm(P_tgt, dim=-1, keepdim=True)
        P1_src = P_src / norm_src.masked_fill(norm_src == 0, 1)
        P2_tgt = P_tgt / norm_tgt.masked_fill(norm_tgt == 0, 1)

        ## Node embedding with unary geometric prior
        emb1 = torch.cat((U_src, F_src, P1_src.transpose(1, 2)),
                         dim=1).transpose(1, 2)
        emb2 = torch.cat((U_tgt, F_tgt, P2_tgt.transpose(1, 2)),
                         dim=1).transpose(1, 2)

        for i in range(self.gnn_layer):

            gnn_layer = getattr(self, 'gnn_layer_{}'.format(i))
            if i == 0:
                emb1, emb2 = gnn_layer([A_src, emb1], [A_tgt, emb2])
            else:
                # Concatenate each embedding with the soft-matched embedding
                # of the other graph; propagate over the structure-weighted
                # adjacencies computed in the previous iteration.
                emb1_new = torch.cat((emb1, torch.bmm(s, emb2)), dim=-1)
                emb2_new = torch.cat(
                    (emb2, torch.bmm(s.transpose(1, 2), emb1)), dim=-1)
                emb1, emb2 = gnn_layer([AA_src, emb1_new], [BB_tgt, emb2_new])
            affinity = getattr(self, 'affinity_{}'.format(i))
            s = affinity(emb1, emb2)

            ## Commutative function f: exp(cosine similarity) between nodes
            ## of the same graph. Vectorized so AA is (B, N1, N1) and BB is
            ## (B, N2, N2) even when N1 != N2; F.normalize clamps tiny norms,
            ## so zero embeddings give exp(0) = 1 instead of NaN.
            unit1 = torch.nn.functional.normalize(emb1, dim=-1)
            unit2 = torch.nn.functional.normalize(emb2, dim=-1)
            AA = torch.exp(torch.bmm(unit1, unit1.transpose(1, 2)))
            BB = torch.exp(torch.bmm(unit2, unit2.transpose(1, 2)))

            ## Pairwise structural context
            AA_src = torch.mul(AA, A_src)
            BB_tgt = torch.mul(BB, A_tgt)

            if i == 1:
                ## QC-optimization. G is the gradient of
                ## lb * ||AA_src - X BB_tgt X^T||_F^2 - (1 - lb) * <s, X>;
                ## sh_layer of the shifted negated gradient gives the update
                ## direction S, followed by a 2/(k+2) step.
                X = s
                lb = 0.1  ## Balancing unary term and pairwise term
                for _ in range(3):
                    for ik in range(3):

                        perm_tgt = torch.bmm(torch.bmm(X, BB_tgt),
                                             X.transpose(1, 2))
                        P = (AA_src - perm_tgt)
                        # P already lives on the inputs' device; no .cuda().
                        V = -2 * P
                        V_X = torch.bmm(V.transpose(1, 2), X)
                        V_XB = torch.bmm(V_X, BB_tgt)
                        VX = torch.bmm(V, X)
                        VX_B = torch.bmm(VX, BB_tgt.transpose(1, 2))
                        N = -s
                        G = lb * (V_XB + VX_B) + (1 - lb) * N
                        G_sim = -G - torch.min(-G)
                        S = self.sh_layer(G_sim, ns_src, ns_tgt)
                        lam = 2 / (ik + 2)
                        X = X + lam * (S - X)
                    X = self.sh_layer(X, ns_src, ns_tgt)
                s = 1 * s + 0.5 * X  ## For faster convergence

            ## Normalization in evaluation: scale each batch item by its max.
            if not self.training:
                s = s / torch.amax(s, dim=(1, 2), keepdim=True)

            s = self.sm_layer(s, ns_src, ns_tgt)
            s = self.sh_layer(s, ns_src, ns_tgt)

        return s, U_src, F_src, U_tgt, F_tgt, AA, BB