def get_multi_cluster_samples(self,
                              idx,
                              ks=None,
                              mode="all",
                              cls_list=None,
                              shuffle=True):
    """
    Sample several clusters of graphs, one cluster per class.

    ``self.get_k_samples`` is called once for every ``(k, cls)`` pair and the
    per-cluster results are concatenated into flat lists.

    :param idx: forwarded unchanged to ``self.get_k_samples``
    :param ks: number of samples to draw for each cluster (required)
    :param mode: sampling strategy, forwarded to ``self.get_k_samples``
    :param cls_list: one class per cluster (required); integer entries are
        looked up in ``self.classes``. If it is shorter than ``ks`` the
        missing classes are drawn at random from the classes not yet in the
        list. The list passed by the caller is not modified.
    :param shuffle: forwarded to ``self.get_k_samples``
    :return: ``(anno_list, perm_mat_true_list, adj_mat_list,
        node_label_list, graph_indices_list)``. ``graph_indices_list`` holds
        the class of every sampled graph. ``perm_mat_true_list`` has one
        entry per pair produced by ``lexico_iter(graph_indices_list)``: the
        permutation matrix for pairs of the same class, and a ``(1, 1)``
        zero placeholder for pairs of different classes.
    :raises ValueError: if more clusters are requested than there are unused
        classes to fill ``cls_list`` with
    """
    assert ks is not None
    assert cls_list is not None

    # Resolve integer class indices on a private copy, so the list owned by
    # the caller is never mutated (re-using one list across calls would
    # otherwise freeze the randomly drawn classes after the first call).
    cls_list = [
        self.classes[cls] if isinstance(cls, int) else cls for cls in cls_list
    ]

    missing = len(ks) - len(cls_list)
    if missing > 0:
        unused = []
        for candidate in self.classes:
            if candidate not in cls_list and candidate not in unused:
                unused.append(candidate)
        # Without this guard the rejection sampling below never terminates.
        if missing > len(unused):
            raise ValueError(
                f"Cannot draw {missing} additional distinct classes: only "
                f"{len(unused)} unused classes are available")
        for _ in range(missing):
            cls = self.classes[random.randrange(0, len(self.classes))]
            while cls in cls_list:
                cls = self.classes[random.randrange(0, len(self.classes))]
            cls_list.append(cls)
    assert len(ks) == len(cls_list)

    anno_list = []
    perm_mat_list = []
    adj_mat_list = []
    node_label_list = []
    graph_indices_list = []
    for k, cls in zip(ks, cls_list):
        annos, perms, adjs, nodes, _ = self.get_k_samples(idx=idx,
                                                          k=k,
                                                          mode=mode,
                                                          cls=cls,
                                                          shuffle=shuffle)
        anno_list.extend(annos)
        perm_mat_list.extend(perms)
        adj_mat_list.extend(adjs)
        node_label_list.extend(nodes)
        graph_indices_list.extend([cls] * k)

    # Walk over all graph pairs; same-class pairs consume the per-cluster
    # permutation matrices in order, cross-class pairs get a placeholder.
    cnt = 0
    perm_mat_true_list = []
    for i, j in lexico_iter(graph_indices_list):
        if i == j:
            perm_mat_true_list.append(perm_mat_list[cnt])
            cnt += 1
        else:
            perm_mat_true_list.append(np.zeros((1, 1)))
    assert cnt == len(perm_mat_list)
    return anno_list, perm_mat_true_list, adj_mat_list, node_label_list, graph_indices_list
    def get_k_samples(self,
                      idx,
                      k,
                      mode,
                      cls=None,
                      shuffle=True,
                      num_iterations=200):
        """
        Randomly sample k annotated objects of one class and build the
        ground-truth permutation matrix for every pair of them.

        :param idx: must be None; indexed sampling is not implemented
        :param k: number of datapoints in sample
        :param mode: sampling strategy: "all", "superset" (pairs only),
            "rectangle" or "intersection" (both only implemented for k = 2)
        :param cls: None for random class, or a class name / class index
        :param shuffle: random shuffle the keypoints
        :param num_iterations: forwarded to ``get_pair_superset`` in
            "superset" mode and to the resampling retry; not used otherwise
        :return: (k samples of data, k choose 2 groundtruth permutation matrices)
        :raises NotImplementedError: if ``idx`` is given or ``mode`` is not
            supported for the requested ``k``
        """
        if idx is not None:
            raise NotImplementedError(
                "No indexed sampling implemented for willow.")
        if cls is None:
            cls = random.randrange(0, len(self.classes))
        elif isinstance(cls, str):
            cls = self.classes.index(cls)
        assert isinstance(cls, int) and 0 <= cls < len(self.classes)

        if mode == "superset" and k == 2:
            anno_list, perm_mat = self.get_pair_superset(
                cls=cls, shuffle=shuffle, num_iterations=num_iterations)
            return anno_list, [perm_mat]

        anno_list = []
        for xml_name in random.sample(self.mat_list[cls], k):
            anno_dict = self.__get_anno_dict(xml_name, cls)
            if shuffle:
                random.shuffle(anno_dict["keypoints"])
            # Compare by value: identity ("is not") on integers only works by
            # accident of the interpreter's small-int cache.
            if len(anno_dict["keypoints"]) != (10 + self.outlier):
                # Unexpected keypoint count: discard and resample everything.
                return self.get_k_samples(idx, k, mode, cls, shuffle,
                                          num_iterations)
            anno_list.append(anno_dict)

        perm_mat_list = [
            np.zeros([len(_["keypoints"]) for _ in anno_pair],
                     dtype=np.float32) for anno_pair in lexico_iter(anno_list)
        ]
        # All samples share the same class, so every pair gets a real matrix.
        for n, (s1, s2) in enumerate(lexico_iter(anno_list)):
            row_list = []
            col_list = []
            for i, keypoint in enumerate(s1["keypoints"]):
                for j, _keypoint in enumerate(s2["keypoints"]):
                    # Keypoints named 10 are never matched.
                    if (keypoint["name"] == _keypoint["name"]
                            and keypoint["name"] != 10):
                        perm_mat_list[n][i, j] = 1
                        row_list.append(i)
                        col_list.append(j)
                        break
            if mode == "all":
                pass
            elif mode == "rectangle" and k == 2:  # so far only implemented for k = 2
                # Keep only the source keypoints that have a match.
                row_list.sort()
                perm_mat_list[n] = perm_mat_list[n][row_list, :]
                s1["keypoints"] = [s1["keypoints"][i] for i in row_list]
                assert perm_mat_list[n].size == len(s1["keypoints"]) * len(
                    s2["keypoints"])
            elif mode == "intersection" and k == 2:  # so far only implemented for k = 2
                # Keep only keypoints matched on both sides.
                row_list.sort()
                col_list.sort()
                perm_mat_list[n] = perm_mat_list[n][row_list, :]
                perm_mat_list[n] = perm_mat_list[n][:, col_list]
                s1["keypoints"] = [s1["keypoints"][i] for i in row_list]
                s2["keypoints"] = [s2["keypoints"][j] for j in col_list]
            else:
                raise NotImplementedError(
                    f"Unknown sampling strategy {mode}")

        return anno_list, perm_mat_list
def easy_visualize(
    graphs,
    positions,
    n_points,
    images,
    unary_costs,
    quadratic_costs,
    matchings,
    true_matchings,
    string_info,
    reduced_vis,
    produce_pdf=True,
):
    """
    Render the graphs and pairwise matchings of one batch into figures and
    collect them (plus cost tables) in a LaTeX document.

    Shapes below are the ones stated by the original author.

    :param graphs: [num_graphs, bs, ...]
    :param positions: [num_graphs, bs, 2, max_n_p]
    :param n_points: [num_graphs, bs, n_p]
    :param images: [num_graphs, bs, size, size]
    :param unary_costs: [num_graphs choose 2, bs, max_n_p, max_n_p]
    :param quadratic_costs: [num_graphs choose 2, bs, max_n_p, max_n_p]
    :param matchings: [num_graphs choose 2, bs, max_n_p, max_n_p]
    :param true_matchings: ground-truth matchings, cropped the same way as
        ``matchings``
    :param string_info: text embedded in the name of the output file
    :param reduced_vis: if true, skip the single-graph figures, the
        node-only merged figures and the cost/matching tables
    :param produce_pdf: whether to call ``produce_pdf`` on the LaTeX document
    """

    def crop_to_valid(matching_list):
        # Cut each matrix down to the point counts of its graph pair.
        return [
            [m[:n_p_x, :n_p_y] for m, n_p_x, n_p_y in zip(match, n_p_x_batch, n_p_y_batch)]
            for match, (n_p_x_batch, n_p_y_batch) in zip(matching_list, lexico_iter(n_points))
        ]

    positions = [[p[:num] for p, num in zip(pos, n_p)] for pos, n_p in zip(positions, n_points)]
    matchings = crop_to_valid(matchings)
    true_matchings = crop_to_valid(true_matchings)

    visualization_string = "visualization"
    latex_file = lu.LatexFile(visualization_string)
    vis_dir = os.path.join(cfg.model_dir, visualization_string)
    unnorm = UnNormalize(cfg.NORM_MEANS, cfg.NORM_STD)
    images = [[unnorm(im) for im in im_b] for im_b in images]

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(vis_dir, exist_ok=True)

    # Transpose every input from [graph/pair, batch] to [batch, graph/pair].
    batch = zip(
        zip(*graphs),
        zip(*positions),
        zip(*images),
        zip(*unary_costs),
        zip(*quadratic_costs),
        zip(*matchings),
        zip(*true_matchings),
    )
    for b, (graph_l, pos_l, im_l, unary_costs_l, quadratic_costs_l, matchings_l, true_matchings_l) in enumerate(batch):
        if not reduced_vis:
            files_single = []
            for i, (graph, pos, im) in enumerate(zip(graph_l, pos_l, im_l)):
                f_single = visualize_graph(graph, pos, im, suffix=f"single_{i}", idx=b, vis_dir=vis_dir)
                f_single_simple = visualize_graph(
                    graph, pos, im, suffix=f"single_simple_{i}", idx=b, vis_dir=vis_dir, mode="triang"
                )
                files_single.append(f_single)
                files_single.append(f_single_simple)
            latex_file.add_section_from_figures(
                name=f"Single Graphs ({b})", list_of_filenames=files_single, common_scale=0.7
            )

        files_mge = []
        for (
            unary_c,
            quadratic_c,
            matching,
            true_matching,
            (graph_src, graph_tgt),
            (pos_src, pos_tgt),
            (im_src, im_tgt),
            (i, j),
        ) in n_and_l_iter_parallel(
            n=[unary_costs_l, quadratic_costs_l, matchings_l, true_matchings_l], l=[graph_l, pos_l, im_l], enum=True
        ):
            im_mge, p_mge, graph_mge, edges_corrct_mge, node_colors_mge, true_graph = merge_images_and_graphs(
                graph_src, graph_tgt, pos_src, pos_tgt, im_src, im_tgt, new_edges=matching, true_edges=true_matching
            )
            # colors[2] marks correct edges, colors[3] incorrect ones.
            edge_colors = [colors[2] if corr else colors[3] for corr in edges_corrct_mge]
            f_mge = visualize_graph(
                graph_mge,
                p_mge,
                im_mge,
                suffix=f"mge_{i}-{j}",
                idx=b,
                vis_dir=vis_dir,
                mode="only_edges",
                edge_colors=edge_colors,
                node_colors=node_colors_mge,
                true_graph=true_graph,
            )
            files_mge.append(f_mge)

            if not reduced_vis:
                f_mge_nodes = visualize_graph(
                    graph_mge,
                    p_mge,
                    im_mge,
                    suffix=f"mge_nodes_{i}-{j}",
                    idx=b,
                    vis_dir=vis_dir,
                    mode="only_nodes",
                    edge_colors=edge_colors,
                    node_colors=node_colors_mge,
                    true_graph=true_graph,
                )
                files_mge.append(f_mge_nodes)
                costs_and_matchings = dict(
                    unary_cost=unary_c, quadratic_cost=quadratic_c, matchings=matching, true_matching=true_matching
                )
                for key, value in costs_and_matchings.items():
                    latex_file.add_section_from_dataframe(
                        name=f"{key} ({b}, {i}-{j})", dataframe=pd.DataFrame(value).round(2)
                    )

        latex_file.add_section_from_figures(name=f"Matched Graphs ({b})", list_of_filenames=files_mge, common_scale=0.7)

    # Named "timestamp" so the stdlib module name "time" is not shadowed.
    timestamp = "{date:%Y-%m-%d_%H-%M-%S}".format(date=datetime.datetime.now())
    suffix = f"{string_info}_{timestamp}"
    output_file = os.path.join(vis_dir, f"{visualization_string}_{suffix}.pdf")
    if produce_pdf:
        latex_file.produce_pdf(output_file=output_file)
# Example #4
# 0
    def forward(
        self,
        images,
        points,
        graphs,
        n_points,
        perm_mats,
        visualize_flag=False,
        visualization_params=None,
    ):
        """
        Compute matchings between every pair of input graphs.

        ``images``, ``points``, ``n_points`` and ``graphs`` are iterated in
        lockstep, one entry per graph. Unary and quadratic costs are built
        for every pair given by ``lexico_iter`` and passed to the solver
        selected by ``cfg.BB_GM.solver_name``.

        :param images: one image batch per graph
        :param points: keypoint positions per graph
        :param graphs: graph batches per graph; their ``x`` attribute is
            overwritten with the extracted node features
        :param n_points: number of keypoints per graph
        :param perm_mats: ground-truth matchings per graph pair; only read
            in training mode, where they are added to the unary costs
        :param visualize_flag: if true, call ``easy_visualize`` on the result
        :param visualization_params: extra keyword arguments for
            ``easy_visualize``
        :return: the matchings produced by the solver
        :raises ValueError: if the configured solver name is unknown
        """
        global_feats = []
        graph_list = []
        for img, pts, num_pts, graph in zip(images, points, n_points, graphs):
            # Feature extraction.
            node_maps = self.node_layers(img)
            edge_maps = self.edge_layers(node_maps)

            pooled = self.final_layers(edge_maps)[0]
            global_feats.append(pooled.reshape((node_maps.shape[0], -1)))

            node_maps = normalize_over_channels(node_maps)
            edge_maps = normalize_over_channels(edge_maps)

            # Gather per-keypoint features from both feature maps.
            # (256, 256) is presumably the input image size -- TODO confirm.
            U = concat_features(
                feature_align(node_maps, pts, num_pts, (256, 256)), num_pts)
            F = concat_features(
                feature_align(edge_maps, pts, num_pts, (256, 256)), num_pts)
            graph.x = torch.cat((U, F), dim=-1)

            graph = self.message_pass_node_features(graph)
            graph_list.append(
                self.build_edge_features_from_node_features(graph))

        # One global weight vector per graph pair.
        pair_globals = []
        for feat_src, feat_tgt in lexico_iter(global_feats):
            pair_globals.append(torch.cat([feat_src, feat_tgt], dim=-1))
        global_weights_list = []
        for pair_global in pair_globals:
            global_weights_list.append(normalize_over_channels(pair_global))

        unary_sims_list = []
        for (graphs_src, graphs_tgt), weights in zip(lexico_iter(graph_list),
                                                     global_weights_list):
            unary_sims_list.append(
                self.vertex_affinity([g.x for g in graphs_src],
                                     [g.x for g in graphs_tgt], weights))

        # Similarities to costs
        unary_costs_list = []
        for unary_sims in unary_sims_list:
            unary_costs_list.append([-sim for sim in unary_sims])

        if self.training:
            with_margin = []
            for costs, gt_perms, (ns_src, ns_tgt) in zip(
                    unary_costs_list, perm_mats, lexico_iter(n_points)):
                # Add margin with alpha = 1.0
                with_margin.append([
                    cost + 1.0 * gt[:n_src, :n_tgt] for cost, gt, n_src, n_tgt
                    in zip(costs, gt_perms, ns_src, ns_tgt)
                ])
            unary_costs_list = with_margin

        quadratic_sims_list = []
        for (graphs_src, graphs_tgt), weights in zip(lexico_iter(graph_list),
                                                     global_weights_list):
            quadratic_sims_list.append(
                self.edge_affinity([g.edge_attr for g in graphs_src],
                                   [g.edge_attr for g in graphs_tgt],
                                   weights))

        # Similarities to costs
        quadratic_costs_list = []
        for quadratic_sims in quadratic_sims_list:
            quadratic_costs_list.append([-0.5 * sim for sim in quadratic_sims])

        solver_name = cfg.BB_GM.solver_name
        if solver_name == "lpmp":
            edge_indices = [[g.edge_index for g in graph_batch]
                            for graph_batch in graph_list]
            # One pairwise solver per graph pair.
            pair_solvers = []
            for (edges_left, edges_right), (ns_src, ns_tgt) in zip(
                    lexico_iter(edge_indices), lexico_iter(n_points)):
                pair_solvers.append(
                    GraphMatchingModule(
                        edges_left,
                        edges_right,
                        ns_src,
                        ns_tgt,
                        cfg.BB_GM.lambda_val,
                        cfg.BB_GM.solver_params,
                    ))
            matchings = []
            for solver, unary, quadratic in zip(pair_solvers,
                                                unary_costs_list,
                                                quadratic_costs_list):
                matchings.append(solver(unary, quadratic))
        elif solver_name == "multigraph":
            edge_indices = [[g.edge_index for g in graph_batch]
                            for graph_batch in graph_list]
            joint_solver = MultiGraphMatchingModule(edge_indices, n_points,
                                                    cfg.BB_GM.lambda_val,
                                                    cfg.BB_GM.solver_params)
            matchings = joint_solver(unary_costs_list, quadratic_costs_list)
        else:
            raise ValueError(f"Unknown solver {solver_name}")

        if visualize_flag:
            easy_visualize(
                graph_list,
                points,
                n_points,
                images,
                unary_costs_list,
                quadratic_costs_list,
                matchings,
                **visualization_params,
            )

        return matchings
    def get_k_samples(self,
                      idx,
                      k,
                      mode,
                      cls=None,
                      shuffle=True,
                      num_iterations=200):
        """
        Randomly get a sample of k objects from VOC-Berkeley keypoints dataset

        :param idx: Index of datapoint to sample, None for random sampling
            (only None is implemented)
        :param k: number of datapoints in sample
        :param mode: sampling strategy: "superset" (pairs only),
            "intersection" or "all"
        :param cls: None for random class, or specify for a certain set; a
            class listed in ``ban_classes`` is replaced by a random allowed
            class
        :param shuffle: random shuffle the keypoints
        :param num_iterations: maximum number of iterations for sampling a datapoint
        :return: ``(anno_list, perm_mat_list, adj_mat_list, node_label_list,
            cls)`` with k samples and k choose 2 groundtruth permutation
            matrices. NOTE: "superset" mode with k = 2 returns only
            ``(anno_list, [perm_mat])``.
        :raises NotImplementedError: if ``idx`` is given, or ``mode`` is not
            supported for the requested ``k``
        """
        if idx is not None:
            raise NotImplementedError(
                "No indexed sampling implemented for PVOC.")
        if cls is None:
            cls = random.randrange(0, len(self.classes))
        elif isinstance(cls, str):
            cls = self.classes.index(cls)
        assert isinstance(cls, int) and 0 <= cls < len(self.classes)

        # Indices of all classes that are allowed to be sampled.
        cls_choose = [
            i for i, name in enumerate(self.classes) if name not in ban_classes
        ]
        if cls not in cls_choose:
            cls = cls_choose[random.randrange(0, len(cls_choose))]

        if mode == "superset" and k == 2:  # superset sampling only valid for pairs
            anno_list, perm_mat = self.get_pair_superset(
                cls=cls, shuffle=shuffle, num_iterations=num_iterations)
            return anno_list, [perm_mat]
        elif mode == "intersection":
            # Bound up front so the assert below reports a failed search
            # instead of an UnboundLocalError when num_iterations <= 0.
            anno_list = []
            for i in range(num_iterations):
                xml_used = list(random.sample(self.xml_list[cls], 2))
                anno_dict_1, anno_dict_2 = [
                    self.__get_anno_dict(xml, cls) for xml in xml_used
                ]
                kp_names_1 = [
                    keypoint["name"] for keypoint in anno_dict_1["keypoints"]
                ]
                kp_names_2 = [
                    keypoint["name"] for keypoint in anno_dict_2["keypoints"]
                ]
                kp_names_filtered = set(kp_names_1).intersection(kp_names_2)
                anno_dict_1["keypoints"] = [
                    kp for kp in anno_dict_1["keypoints"]
                    if kp["name"] in kp_names_2
                ]
                anno_dict_2["keypoints"] = [
                    kp for kp in anno_dict_2["keypoints"]
                    if kp["name"] in kp_names_1
                ]

                anno_list = [anno_dict_1, anno_dict_2]
                # Extend the pair with further samples that share more than
                # half of the common keypoints.
                for j in range(num_iterations):
                    if j > 2 * len(self.xml_list[cls]) or len(anno_list) == k:
                        break
                    xml = random.choice(self.xml_list[cls])
                    anno_dict = self.__get_anno_dict(xml, cls)
                    anno_dict["keypoints"] = [
                        kp for kp in anno_dict["keypoints"]
                        if kp["name"] in kp_names_filtered
                    ]
                    if len(anno_dict["keypoints"]) > len(
                            kp_names_filtered) // 2 and xml not in xml_used:
                        xml_used.append(xml)
                        anno_list.append(anno_dict)
                if len(anno_list) == k:  # k samples found that match restrictions
                    break
            assert len(anno_list) == k
        elif mode == "all":
            anno_list = []
            for xml_name in random.sample(self.xml_list[cls],
                                          len(self.xml_list[cls])):
                anno_dict = self.__get_anno_dict(xml_name, cls)
                anno_list.append(anno_dict)
                if len(anno_list) == k:  # k samples found that match restrictions
                    break
        else:
            # Previously fell through with anno_list unbound.
            raise NotImplementedError(
                f"Unknown or unsupported sampling strategy {mode} for k={k}")

        if shuffle:
            for anno_dict in anno_list:
                random.shuffle(anno_dict["keypoints"])

        # build permutation matrices
        perm_mat_list = [
            np.zeros([len(_["keypoints"]) for _ in anno_pair],
                     dtype=np.float32) for anno_pair in lexico_iter(anno_list)
        ]
        for n, (s1, s2) in enumerate(lexico_iter(anno_list)):
            for i, keypoint in enumerate(s1["keypoints"]):
                for j, _keypoint in enumerate(s2["keypoints"]):
                    if keypoint["name"] == _keypoint["name"]:
                        perm_mat_list[n][i, j] = 1

        # build ground truth graph and node ground truth label
        adj_mat_list = []
        node_label_list = []
        for s in anno_list:
            kp_idx = [
                KPT_NAMES[self.classes[cls]].index(kp["name"])
                for kp in s["keypoints"]
            ]
            adj_mat = self.adj_list[cls][kp_idx, :]
            adj_mat = adj_mat[:, kp_idx]
            adj_mat_list.append(adj_mat)
            if np.sum(adj_mat) == 0:
                # A graph without any edge is unusable: resample everything.
                return self.get_k_samples(idx,
                                          k,
                                          mode,
                                          cls=cls,
                                          shuffle=shuffle,
                                          num_iterations=num_iterations)
            # The node label of a keypoint is its index in KPT_NAMES.
            node_label_list.append(kp_idx)

        return anno_list, perm_mat_list, adj_mat_list, node_label_list, cls
    def get_k_samples(self,
                      idx,
                      k,
                      mode,
                      cls=None,
                      shuffle=True,
                      num_iterations=200):
        """
        Build one sample from the batch returned by ``self.fetch()``.

        The class is taken from ``self.cls_list[self.pin]``; the ``idx``,
        ``k``, ``cls`` and ``num_iterations`` arguments are accepted but not
        used by this implementation.

        :param mode: sampling strategy; "intersection" keeps only keypoints
            present in both samples and requires exactly two samples
        :param shuffle: random shuffle the keypoints
        :return: (annotations, pairwise groundtruth permutation matrices,
            adjacency matrices, node labels, class index)
        """
        cls = self.cls_list[self.pin]
        anno_list = []
        for xml_name in self.fetch():
            anno_list.append(self.__get_anno_dict(xml_name, cls))

        if mode == "intersection":
            assert self.num_sample == 2 and len(
                anno_list
            ) == 2, "intersection mode are only supported by pair sample"
            first, second = anno_list
            names_first = [kp["name"] for kp in first["keypoints"]]
            names_second = [kp["name"] for kp in second["keypoints"]]
            # Drop every keypoint that the other sample does not have.
            first["keypoints"] = list(
                filter(lambda kp: kp["name"] in names_second,
                       first["keypoints"]))
            second["keypoints"] = list(
                filter(lambda kp: kp["name"] in names_first,
                       second["keypoints"]))

        if shuffle:
            for anno in anno_list:
                random.shuffle(anno["keypoints"])

        # One permutation matrix per sample pair: entry (row, col) is 1 when
        # the two keypoints carry the same name.
        perm_mat_list = []
        for src, tgt in lexico_iter(anno_list):
            src_kps = src["keypoints"]
            tgt_kps = tgt["keypoints"]
            perm = np.zeros([len(src_kps), len(tgt_kps)], dtype=np.float32)
            for row, src_kp in enumerate(src_kps):
                for col, tgt_kp in enumerate(tgt_kps):
                    if src_kp["name"] == tgt_kp["name"]:
                        perm[row, col] = 1
            perm_mat_list.append(perm)

        # Ground-truth adjacency and node labels share the same keypoint
        # indices, i.e. the position of each keypoint name in KPT_NAMES.
        adj_mat_list = []
        node_label_list = []
        for anno in anno_list:
            kp_idx = []
            for kp in anno["keypoints"]:
                kp_idx.append(KPT_NAMES[self.classes[cls]].index(kp["name"]))
            sub_adj = self.adj_list[cls][kp_idx, :]
            adj_mat_list.append(sub_adj[:, kp_idx])
            node_label_list.append(list(kp_idx))

        return anno_list, perm_mat_list, adj_mat_list, node_label_list, cls