def get_multi_cluster_samples(self, idx, ks=None, mode="all", cls_list=None, shuffle=True):
    assert ks is not None
    assert cls_list is not None
    anno_list = []
    perm_mat_list = []
    adj_mat_list = []
    node_label_list = []
    graph_indices_list = []
    # Resolve integer class indices to class names.
    for i, cls in enumerate(cls_list):
        if type(cls) == int:
            clss = self.classes[cls]
            cls_list[i] = clss
    # Pad cls_list with additional, distinct random classes until it matches ks.
    if len(cls_list) < len(ks):
        for i in range(len(cls_list), len(ks)):
            cls_num = random.randrange(0, len(self.classes))
            cls = self.classes[cls_num]
            while cls in cls_list:
                cls_num = random.randrange(0, len(self.classes))
                cls = self.classes[cls_num]
            cls_list.append(cls)
    assert len(ks) == len(cls_list)
    for k, cls in zip(ks, cls_list):
        annos, perms, adjs, nodes, _ = self.get_k_samples(idx=idx, k=k, mode=mode, cls=cls, shuffle=shuffle)
        anno_list.extend(annos)
        perm_mat_list.extend(perms)
        adj_mat_list.extend(adjs)
        node_label_list.extend(nodes)
        # if type(cls) == str:
        #     cls = self.classes.index(cls)
        graph_indices_list.extend([cls for _ in range(k)])
    cnt = 0
    perm_mat_true_list = []
    # Pairs of graphs from the same class get their ground-truth permutation matrix;
    # cross-class pairs have no matching and receive a dummy matrix.
    for i, j in lexico_iter(graph_indices_list):
        if i == j:
            perm_mat_true_list.append(perm_mat_list[cnt])
            cnt += 1
        else:
            perm_mat_true_list.append(np.zeros((1, 1)))
    assert cnt == len(perm_mat_list)
    return anno_list, perm_mat_true_list, adj_mat_list, node_label_list, graph_indices_list
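# The pair bookkeeping above relies on `lexico_iter` yielding all unordered pairs in
# lexicographic order. A minimal sketch of a compatible helper, assuming it is a plain
# 2-combination iterator (the repository's own definition may differ):
from itertools import combinations

def lexico_iter(lst):
    # lexico_iter([a, b, c]) -> (a, b), (a, c), (b, c)
    return combinations(lst, 2)
# With this definition, a batch of k graphs yields k * (k - 1) // 2 pairwise permutation
# matrices, matching the "k choose 2" counts used throughout this code.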
def get_k_samples(self, idx, k, mode, cls=None, shuffle=True, num_iterations=200):
    """
    Randomly get a sample of k objects from the Willow ObjectClass keypoint dataset.
    :param idx: index of datapoint to sample, None for random sampling
    :param k: number of datapoints in the sample
    :param mode: sampling strategy
    :param cls: None for random class, or specify for a certain class
    :param shuffle: randomly shuffle the keypoints
    :param num_iterations: maximum number of iterations for sampling a datapoint
    :return: (k samples of data, k choose 2 ground-truth permutation matrices)
    """
    if idx is not None:
        raise NotImplementedError("No indexed sampling implemented for willow.")
    if cls is None:
        cls = random.randrange(0, len(self.classes))
    elif type(cls) == str:
        cls = self.classes.index(cls)
    assert type(cls) == int and 0 <= cls < len(self.classes)

    if mode == "superset" and k == 2:
        anno_list, perm_mat = self.get_pair_superset(cls=cls, shuffle=shuffle, num_iterations=num_iterations)
        return anno_list, [perm_mat]

    anno_list = []
    cls_list = []
    for xml_name in random.sample(self.mat_list[cls], k):
        anno_dict = self.__get_anno_dict(xml_name, cls)
        if shuffle:
            random.shuffle(anno_dict["keypoints"])
        if len(anno_dict["keypoints"]) != (10 + self.outlier):
            # Unexpected keypoint count: resample the whole batch.
            return self.get_k_samples(idx, k, mode, cls, shuffle, num_iterations)
        anno_list.append(anno_dict)
        cls_list.append(cls)

    perm_mat_list = [
        np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) for anno_pair in lexico_iter(anno_list)
    ]
    for n, ((s1, s2), (cls1, cls2)) in enumerate(zip(lexico_iter(anno_list), lexico_iter(cls_list))):
        row_list = []
        col_list = []
        if cls1 == cls2:
            for i, keypoint in enumerate(s1["keypoints"]):
                for j, _keypoint in enumerate(s2["keypoints"]):
                    if keypoint["name"] == _keypoint["name"] and keypoint["name"] != 10:
                        perm_mat_list[n][i, j] = 1
                        row_list.append(i)
                        col_list.append(j)
                        break
        if mode == "all":
            pass
        elif mode == "rectangle" and k == 2:  # so far only implemented for k = 2
            row_list.sort()
            perm_mat_list[n] = perm_mat_list[n][row_list, :]
            s1["keypoints"] = [s1["keypoints"][i] for i in row_list]
            assert perm_mat_list[n].size == len(s1["keypoints"]) * len(s2["keypoints"])
        elif mode == "intersection" and k == 2:  # so far only implemented for k = 2
            row_list.sort()
            col_list.sort()
            perm_mat_list[n] = perm_mat_list[n][row_list, :]
            perm_mat_list[n] = perm_mat_list[n][:, col_list]
            s1["keypoints"] = [s1["keypoints"][i] for i in row_list]
            s2["keypoints"] = [s2["keypoints"][j] for j in col_list]
        else:
            raise NotImplementedError(f"Unknown sampling strategy {mode}")

    return anno_list, perm_mat_list
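# Hypothetical sanity check (not part of the dataset class): get_k_samples returns one
# ground-truth permutation matrix per unordered pair of sampled images, so a quick
# invariant to verify on its output is:
def _check_k_sample_shapes(anno_list, perm_mat_list):
    pairs = list(lexico_iter(anno_list))
    assert len(perm_mat_list) == len(pairs)  # == k choose 2
    for mat, (a, b) in zip(perm_mat_list, pairs):
        # each matrix is (keypoints in first sample) x (keypoints in second sample)
        assert mat.shape == (len(a["keypoints"]), len(b["keypoints"]))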
def easy_visualize(
    graphs,
    positions,
    n_points,
    images,
    unary_costs,
    quadratic_costs,
    matchings,
    true_matchings,
    string_info,
    reduced_vis,
    produce_pdf=True,
):
    """
    :param graphs: [num_graphs, bs, ...]
    :param positions: [num_graphs, bs, 2, max_n_p]
    :param n_points: [num_graphs, bs, n_p]
    :param images: [num_graphs, bs, size, size]
    :param unary_costs: [num_graphs choose 2, bs, max_n_p, max_n_p]
    :param quadratic_costs: [num_graphs choose 2, bs, max_n_p, max_n_p]
    :param matchings: [num_graphs choose 2, bs, max_n_p, max_n_p]
    """
    # Truncate padded positions and matchings to the true number of keypoints per graph.
    positions = [[p[:num] for p, num in zip(pos, n_p)] for pos, n_p in zip(positions, n_points)]
    matchings = [
        [m[:n_p_x, :n_p_y] for m, n_p_x, n_p_y in zip(match, n_p_x_batch, n_p_y_batch)]
        for match, (n_p_x_batch, n_p_y_batch) in zip(matchings, lexico_iter(n_points))
    ]
    true_matchings = [
        [m[:n_p_x, :n_p_y] for m, n_p_x, n_p_y in zip(match, n_p_x_batch, n_p_y_batch)]
        for match, (n_p_x_batch, n_p_y_batch) in zip(true_matchings, lexico_iter(n_points))
    ]

    visualization_string = "visualization"
    latex_file = lu.LatexFile(visualization_string)
    vis_dir = os.path.join(cfg.model_dir, visualization_string)
    unnorm = UnNormalize(cfg.NORM_MEANS, cfg.NORM_STD)
    images = [[unnorm(im) for im in im_b] for im_b in images]
    if not os.path.exists(vis_dir):
        os.makedirs(vis_dir)

    batch = zip(
        zip(*graphs),
        zip(*positions),
        zip(*images),
        zip(*unary_costs),
        zip(*quadratic_costs),
        zip(*matchings),
        zip(*true_matchings),
    )

    for b, (graph_l, pos_l, im_l, unary_costs_l, quadratic_costs_l, matchings_l, true_matchings_l) in enumerate(batch):
        if not reduced_vis:
            files_single = []
            for i, (graph, pos, im) in enumerate(zip(graph_l, pos_l, im_l)):
                f_single = visualize_graph(graph, pos, im, suffix=f"single_{i}", idx=b, vis_dir=vis_dir)
                f_single_simple = visualize_graph(
                    graph, pos, im, suffix=f"single_simple_{i}", idx=b, vis_dir=vis_dir, mode="triang"
                )
                files_single.append(f_single)
                files_single.append(f_single_simple)

            latex_file.add_section_from_figures(
                name=f"Single Graphs ({b})", list_of_filenames=files_single, common_scale=0.7
            )

        files_mge = []
        for (
            unary_c,
            quadratic_c,
            matching,
            true_matching,
            (graph_src, graph_tgt),
            (pos_src, pos_tgt),
            (im_src, im_tgt),
            (i, j),
        ) in n_and_l_iter_parallel(
            n=[unary_costs_l, quadratic_costs_l, matchings_l, true_matchings_l], l=[graph_l, pos_l, im_l], enum=True
        ):
            im_mge, p_mge, graph_mge, edges_correct_mge, node_colors_mge, true_graph = merge_images_and_graphs(
                graph_src, graph_tgt, pos_src, pos_tgt, im_src, im_tgt, new_edges=matching, true_edges=true_matching
            )
            f_mge = visualize_graph(
                graph_mge,
                p_mge,
                im_mge,
                suffix=f"mge_{i}-{j}",
                idx=b,
                vis_dir=vis_dir,
                mode="only_edges",
                edge_colors=[colors[2] if corr else colors[3] for corr in edges_correct_mge],
                node_colors=node_colors_mge,
                true_graph=true_graph,
            )
            files_mge.append(f_mge)

            if not reduced_vis:
                f_mge_nodes = visualize_graph(
                    graph_mge,
                    p_mge,
                    im_mge,
                    suffix=f"mge_nodes_{i}-{j}",
                    idx=b,
                    vis_dir=vis_dir,
                    mode="only_nodes",
                    edge_colors=[colors[2] if corr else colors[3] for corr in edges_correct_mge],
                    node_colors=node_colors_mge,
                    true_graph=true_graph,
                )
                files_mge.append(f_mge_nodes)

            costs_and_matchings = dict(
                unary_cost=unary_c, quadratic_cost=quadratic_c, matchings=matching, true_matching=true_matching
            )
            for key, value in costs_and_matchings.items():
                latex_file.add_section_from_dataframe(
                    name=f"{key} ({b}, {i}-{j})", dataframe=pd.DataFrame(value).round(2)
                )

        latex_file.add_section_from_figures(
            name=f"Matched Graphs ({b})", list_of_filenames=files_mge, common_scale=0.7
        )

    time = "{date:%Y-%m-%d_%H-%M-%S}".format(date=datetime.datetime.now())
    suffix = f"{string_info}_{time}"
    output_file = os.path.join(vis_dir, f"{visualization_string}_{suffix}.pdf")
    if produce_pdf:
        latex_file.produce_pdf(output_file=output_file)
def forward(
    self,
    images,
    points,
    graphs,
    n_points,
    perm_mats,
    visualize_flag=False,
    visualization_params=None,
):
    global_list = []
    orig_graph_list = []
    for image, p, n_p, graph in zip(images, points, n_points, graphs):
        # extract features
        nodes = self.node_layers(image)
        edges = self.edge_layers(nodes)
        global_list.append(self.final_layers(edges)[0].reshape((nodes.shape[0], -1)))
        nodes = normalize_over_channels(nodes)
        edges = normalize_over_channels(edges)

        # arrange features
        U = concat_features(feature_align(nodes, p, n_p, (256, 256)), n_p)
        F = concat_features(feature_align(edges, p, n_p, (256, 256)), n_p)
        node_features = torch.cat((U, F), dim=-1)
        graph.x = node_features

        graph = self.message_pass_node_features(graph)
        orig_graph = self.build_edge_features_from_node_features(graph)
        orig_graph_list.append(orig_graph)

    global_weights_list = [
        torch.cat([global_src, global_tgt], dim=-1) for global_src, global_tgt in lexico_iter(global_list)
    ]
    global_weights_list = [normalize_over_channels(g) for g in global_weights_list]

    unary_costs_list = [
        self.vertex_affinity([item.x for item in g_1], [item.x for item in g_2], global_weights)
        for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list), global_weights_list)
    ]

    # Similarities to costs
    unary_costs_list = [[-x for x in unary_costs] for unary_costs in unary_costs_list]

    if self.training:
        unary_costs_list = [
            [
                x + 1.0 * gt[:dim_src, :dim_tgt]  # Add margin with alpha = 1.0
                for x, gt, dim_src, dim_tgt in zip(unary_costs, perm_mat, ns_src, ns_tgt)
            ]
            for unary_costs, perm_mat, (ns_src, ns_tgt) in zip(unary_costs_list, perm_mats, lexico_iter(n_points))
        ]

    quadratic_costs_list = [
        self.edge_affinity([item.edge_attr for item in g_1], [item.edge_attr for item in g_2], global_weights)
        for (g_1, g_2), global_weights in zip(lexico_iter(orig_graph_list), global_weights_list)
    ]

    # Similarities to costs
    quadratic_costs_list = [[-0.5 * x for x in quadratic_costs] for quadratic_costs in quadratic_costs_list]

    if cfg.BB_GM.solver_name == "lpmp":
        all_edges = [[item.edge_index for item in graph] for graph in orig_graph_list]
        gm_solvers = [
            GraphMatchingModule(
                all_left_edges,
                all_right_edges,
                ns_src,
                ns_tgt,
                cfg.BB_GM.lambda_val,
                cfg.BB_GM.solver_params,
            )
            for (all_left_edges, all_right_edges), (ns_src, ns_tgt) in zip(
                lexico_iter(all_edges), lexico_iter(n_points)
            )
        ]
        matchings = [
            gm_solver(unary_costs, quadratic_costs)
            for gm_solver, unary_costs, quadratic_costs in zip(gm_solvers, unary_costs_list, quadratic_costs_list)
        ]
    elif cfg.BB_GM.solver_name == "multigraph":
        all_edges = [[item.edge_index for item in graph] for graph in orig_graph_list]
        gm_solver = MultiGraphMatchingModule(all_edges, n_points, cfg.BB_GM.lambda_val, cfg.BB_GM.solver_params)
        matchings = gm_solver(unary_costs_list, quadratic_costs_list)
    else:
        raise ValueError(f"Unknown solver {cfg.BB_GM.solver_name}")

    if visualize_flag:
        easy_visualize(
            orig_graph_list,
            points,
            n_points,
            images,
            unary_costs_list,
            quadratic_costs_list,
            matchings,
            **visualization_params,
        )

    return matchings
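# The forward pass above calls `normalize_over_channels` and `concat_features`. The
# sketches below are plausible implementations consistent with how they are used here
# (per-channel L2 normalization, and concatenation of per-image node features truncated
# to the true keypoint counts); the repository's actual helpers may differ.
import torch

def normalize_over_channels(x):
    # Divide each feature map (or feature vector) by its L2 norm along the channel dim.
    channel_norms = torch.norm(x, dim=1, keepdim=True)
    return x / channel_norms

def concat_features(embeddings, num_vertices):
    # Keep only the first n_p keypoint columns per sample, concatenate them,
    # and return a (total_nodes, feature_dim) matrix.
    res = torch.cat([embedding[:, :num_v] for embedding, num_v in zip(embeddings, num_vertices)], dim=-1)
    return res.transpose(0, 1)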
def get_k_samples(self, idx, k, mode, cls=None, shuffle=True, num_iterations=200):
    """
    Randomly get a sample of k objects from the VOC-Berkeley keypoints dataset.
    :param idx: index of datapoint to sample, None for random sampling
    :param k: number of datapoints in the sample
    :param mode: sampling strategy
    :param cls: None for random class, or specify for a certain class
    :param shuffle: randomly shuffle the keypoints
    :param num_iterations: maximum number of iterations for sampling a datapoint
    :return: (k samples of data, k choose 2 ground-truth permutation matrices)
    """
    if idx is not None:
        raise NotImplementedError("No indexed sampling implemented for PVOC.")
    if cls is None:
        cls = random.randrange(0, len(self.classes))
    elif type(cls) == str:
        cls = self.classes.index(cls)
    assert type(cls) == int and 0 <= cls < len(self.classes)

    # Restrict sampling to classes that are not banned.
    cls_choose = [self.classes.index(x) for x in self.classes if x not in ban_classes]
    if cls not in cls_choose:
        cls = cls_choose[random.randrange(0, len(cls_choose))]

    if mode == "superset" and k == 2:  # superset sampling is only valid for pairs
        anno_list, perm_mat = self.get_pair_superset(cls=cls, shuffle=shuffle, num_iterations=num_iterations)
        return anno_list, [perm_mat]
    elif mode == "intersection":
        for i in range(num_iterations):
            xml_used = list(random.sample(self.xml_list[cls], 2))
            anno_dict_1, anno_dict_2 = [self.__get_anno_dict(xml, cls) for xml in xml_used]
            kp_names_1 = [keypoint["name"] for keypoint in anno_dict_1["keypoints"]]
            kp_names_2 = [keypoint["name"] for keypoint in anno_dict_2["keypoints"]]
            kp_names_filtered = set(kp_names_1).intersection(kp_names_2)
            anno_dict_1["keypoints"] = [kp for kp in anno_dict_1["keypoints"] if kp["name"] in kp_names_2]
            anno_dict_2["keypoints"] = [kp for kp in anno_dict_2["keypoints"] if kp["name"] in kp_names_1]
            anno_list = [anno_dict_1, anno_dict_2]
            for j in range(num_iterations):
                if j > 2 * len(self.xml_list[cls]) or len(anno_list) == k:
                    break
                xml = random.choice(self.xml_list[cls])
                anno_dict = self.__get_anno_dict(xml, cls)
                anno_dict["keypoints"] = [kp for kp in anno_dict["keypoints"] if kp["name"] in kp_names_filtered]
                if len(anno_dict["keypoints"]) > len(kp_names_filtered) // 2 and xml not in xml_used:
                    xml_used.append(xml)
                    anno_list.append(anno_dict)
            if len(anno_list) == k:  # k samples found that match the restrictions
                break
        assert len(anno_list) == k
    elif mode == "all":
        anno_list = []
        for xml_name in random.sample(self.xml_list[cls], len(self.xml_list[cls])):
            anno_dict = self.__get_anno_dict(xml_name, cls)
            anno_list.append(anno_dict)
            if len(anno_list) == k:  # k samples found that match the restrictions
                break

    if shuffle:
        for anno_dict in anno_list:
            random.shuffle(anno_dict["keypoints"])

    # build permutation matrices
    perm_mat_list = [
        np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) for anno_pair in lexico_iter(anno_list)
    ]
    for n, (s1, s2) in enumerate(lexico_iter(anno_list)):
        for i, keypoint in enumerate(s1["keypoints"]):
            for j, _keypoint in enumerate(s2["keypoints"]):
                if keypoint["name"] == _keypoint["name"]:
                    perm_mat_list[n][i, j] = 1

    # build ground-truth graphs
    adj_mat_list = []
    for s in anno_list:
        kp_idx = [KPT_NAMES[self.classes[cls]].index(kp["name"]) for kp in s["keypoints"]]
        adj_mat = self.adj_list[cls][kp_idx, :]
        adj_mat = adj_mat[:, kp_idx]
        adj_mat_list.append(adj_mat)
        if np.sum(adj_mat) == 0:
            # Degenerate graph without edges: resample.
            return self.get_k_samples(idx, k, mode, cls=cls, shuffle=shuffle, num_iterations=num_iterations)

    # build node ground-truth labels
    node_label_list = []
    for s in anno_list:
        node_label = [KPT_NAMES[self.classes[cls]].index(kp["name"]) for kp in s["keypoints"]]
        node_label_list.append(node_label)

    return anno_list, perm_mat_list, adj_mat_list, node_label_list, cls
def get_k_samples(self, idx, k, mode, cls=None, shuffle=True, num_iterations=200):
    """
    Randomly get a sample of k objects from the VOC-Berkeley keypoints dataset.
    :param mode: sampling strategy
    :param shuffle: randomly shuffle the keypoints
    :return: (k samples of data, k choose 2 ground-truth permutation matrices)
    """
    cls = self.cls_list[self.pin]
    xml_batch_list = self.fetch()
    anno_list = [self.__get_anno_dict(xml_name, cls) for xml_name in xml_batch_list]

    if mode == "intersection":
        assert self.num_sample == 2 and len(anno_list) == 2, "intersection mode is only supported for pair samples"
        anno_dict_1, anno_dict_2 = anno_list[0], anno_list[1]
        kp_names_1 = [keypoint["name"] for keypoint in anno_dict_1["keypoints"]]
        kp_names_2 = [keypoint["name"] for keypoint in anno_dict_2["keypoints"]]
        # keep only keypoints that appear in both samples
        anno_dict_1["keypoints"] = [kp for kp in anno_dict_1["keypoints"] if kp["name"] in kp_names_2]
        anno_dict_2["keypoints"] = [kp for kp in anno_dict_2["keypoints"] if kp["name"] in kp_names_1]
        anno_list = [anno_dict_1, anno_dict_2]

    if shuffle:
        for anno_dict in anno_list:
            random.shuffle(anno_dict["keypoints"])

    # build permutation matrices
    perm_mat_list = [
        np.zeros([len(_["keypoints"]) for _ in anno_pair], dtype=np.float32) for anno_pair in lexico_iter(anno_list)
    ]
    for n, (s1, s2) in enumerate(lexico_iter(anno_list)):
        for i, keypoint in enumerate(s1["keypoints"]):
            for j, _keypoint in enumerate(s2["keypoints"]):
                if keypoint["name"] == _keypoint["name"]:
                    perm_mat_list[n][i, j] = 1

    # build ground-truth graphs
    adj_mat_list = []
    for s in anno_list:
        kp_idx = [KPT_NAMES[self.classes[cls]].index(kp["name"]) for kp in s["keypoints"]]
        adj_mat = self.adj_list[cls][kp_idx, :]
        adj_mat = adj_mat[:, kp_idx]
        adj_mat_list.append(adj_mat)

    # build node ground-truth labels
    node_label_list = []
    for s in anno_list:
        node_label = [KPT_NAMES[self.classes[cls]].index(kp["name"]) for kp in s["keypoints"]]
        node_label_list.append(node_label)

    return anno_list, perm_mat_list, adj_mat_list, node_label_list, cls
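# Hypothetical usage sketch (not part of the dataset class; `ds` is assumed to be an
# already constructed dataset object): the method returns k annotation dicts,
# k choose 2 permutation matrices, k adjacency matrices restricted to the sampled
# keypoints, k node-label lists, and the class index.
def _describe_sample(ds):
    annos, perms, adjs, labels, cls = ds.get_k_samples(None, ds.num_sample, mode="intersection")
    print(f"class index: {cls}, graphs: {len(annos)}, pairwise matchings: {len(perms)}")
    for adj, lab in zip(adjs, labels):
        print(f"nodes: {len(lab)}, adjacency entries: {int(adj.sum())}")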