def stich(self, im0, im1, unaries0, unaries1, labels0, labels1, pairwise_mask, segmentation):
    """Blend two images with a two-label graph cut and soft label masks.

    A graph cut over ``self.tex_res ** 2`` sites chooses, per pixel, label 0
    (take ``im0``) or label 1 (take ``im1``). Each hard label map is then
    box-blurred and the blurred maps are used as normalised blend weights.

    Args:
        im0, im1: the two images to blend. They are multiplied by
            ``(tex_res, tex_res, 1)`` weight maps, so presumably shaped
            ``(tex_res, tex_res, C)`` -- TODO confirm with callers.
        unaries0, unaries1: per-pixel data costs for label 0 / label 1; they
            are depth-stacked and reshaped to ``(-1, 2)``.
        labels0, labels1, pairwise_mask, segmentation: forwarded unchanged to
            ``self._rgb_grad`` to compute the pairwise edge weights.

    Returns:
        ``(result, labels)``: the blended image (divided by the summed blend
        weights wherever that sum is non-zero) and the graph-cut label map
        reshaped to ``(tex_res, tex_res)`` as float32.
    """
    res = self.tex_res

    graph = gco.gco()
    graph.createGeneralGraph(res ** 2, 2, True)
    try:
        graph.set_data_cost(np.dstack((unaries0, unaries1)).reshape(-1, 2))
        edges_w = self._rgb_grad(im0, im1, labels0, labels1, pairwise_mask, segmentation)
        graph.set_all_neighbors(self.edges_from, self.edges_to, edges_w)
        # Potts smoothness: 0 when neighbouring labels agree, 65 otherwise.
        graph.set_smooth_cost((1 - np.eye(2)) * 65)
        graph.swap()
        labels = graph.get_labels()
    finally:
        # Free the native graph even if one of the calls above raised.
        graph.destroy_graph()

    labels = labels.reshape(res, res).astype(np.float32)

    # cv2.blur needs an integer kernel size: ``res / 100`` is a float on
    # Python 3, which OpenCV rejects. Clamp to 1 (an identity box filter)
    # so textures smaller than 100 px still get a valid kernel.
    ksize = max(1, res // 100)
    label_maps = np.zeros((2, res, res))
    for label in range(2):
        label_maps[label] = cv2.blur(np.float32(labels == label), (ksize, ksize))

    norm_masks = np.sum(label_maps, axis=0)
    result = (np.atleast_3d(label_maps[0]) * im0 +
              np.atleast_3d(label_maps[1]) * im1)
    # Normalise by the total blend weight, skipping pixels where it is zero.
    nonzero = norm_masks != 0
    result[nonzero] /= np.atleast_3d(norm_masks)[nonzero]
    return result, labels
def test_cost_fun():
    """Alpha-expansion with a Python-level smoothness cost callback.

    Three sites joined as a chain (0-1 and 1-2) with two labels. The callback
    makes the (0, 1) pair cheaper (5) when both sites share a label and
    charges 8 for every other call; the expected labelling is [1, 1, 0].
    """
    data_cost = np.array([[8, 1], [8, 2], [2, 8]])

    graph = gco.gco()
    graph.createGeneralGraph(3, 2)
    graph.set_data_cost(data_cost)
    graph.set_all_neighbors(np.arange(0, 2), np.arange(1, 3), np.ones(2))

    def cost_fun(s1, s2, l1, l2):
        same_label_first_pair = s1 == 0 and s2 == 1 and l1 == l2
        return 5 if same_label_first_pair else 8

    graph.set_smooth_cost_function(cost_fun)
    graph.expansion()

    assert np.array_equal(graph.get_labels(), np.array([1, 1, 0]))
def test_gc():
    """Smoke test: creating a general graph sets a non-None ``handle``.

    Builds a 3-site, 2-label graph (third argument ``True``; its meaning is
    defined by the gco wrapper -- presumably the float-energy flag, verify),
    checks the handle, then destroys the graph.
    """
    graph = gco.gco()
    graph.createGeneralGraph(3, 2, True)
    assert graph.handle is not None
    graph.destroy_graph()
def _imread_checked(path, flags=None):
    """Read an image with OpenCV, raising IOError instead of returning None.

    ``cv2.imread`` signals an unreadable/missing file by returning None, which
    would otherwise surface later as an unrelated arithmetic TypeError.
    """
    img = cv2.imread(path) if flags is None else cv2.imread(path, flags=flags)
    if img is None:
        raise IOError('could not read image: {}'.format(path))
    return img


def _grid_edges(iso_mask):
    """Return the 4-neighbour grid edges kept by ``iso_mask``.

    Pixels are addressed by their row-major flat index. Vertical edges join
    (y, x) -> (y + 1, x) and horizontal edges join (y, x) -> (y, x + 1); an
    edge is kept where the corresponding mask derived from ``iso_mask`` and
    its one-pixel finite difference equals 1.

    Returns:
        ``(v_edges_from, v_edges_to, h_edges_from, h_edges_to)``.
    """
    height, width = iso_mask.shape
    dr_v = signal.convolve2d(iso_mask, [[-1, 1]])[:, 1:]
    dr_h = signal.convolve2d(iso_mask, [[-1], [1]])[1:, :]
    where_v = iso_mask - dr_v
    where_h = iso_mask - dr_h

    idxs = np.arange(height * width).reshape(height, width)
    v_edges_from = idxs[:-1, :][where_v[:-1, :] == 1].flatten()
    v_edges_to = idxs[1:, :][where_v[:-1, :] == 1].flatten()
    h_edges_from = idxs[:, :-1][where_h[:, :-1] == 1].flatten()
    h_edges_to = idxs[:, 1:][where_h[:, :-1] == 1].flatten()
    return v_edges_from, v_edges_to, h_edges_from, h_edges_to


def main(unwrap_dir, segm_out_file, gmm_out_file):
    """Vote a texture-space segmentation over frames and fit per-label GMMs.

    Frames are the ``*_unwrap.jpg`` / ``*_segm.png`` / ``*_visibility.jpg``
    files in ``unwrap_dir``, paired by sorted order. For every frame and every
    label other than 'Unseen' and 'BG':

    * pixels whose segmentation colour equals the label's colour add
      ``sqrt(1 - visibility / 255)`` (first channel) to that label's vote;
    * the frame's HSV values (divided by 255) at those pixels are collected.

    One GaussianMixture per label is fitted on its collected pixels. The votes
    become unary costs for a multi-label graph cut (``swap``) over the
    1000 x 1000 texture grid plus seam edges, and the resulting label map is
    written as a colour image.

    Args:
        unwrap_dir: directory containing the per-frame files listed above.
        segm_out_file: output path for the colour-coded segmentation image.
        gmm_out_file: output path for the pickled ``{label: GaussianMixture}``.

    Raises:
        ValueError: if no frames are found, or the three file lists have
            different lengths (pairing by sorted order would be misaligned).
        IOError: if an input image cannot be read or the output image cannot
            be written.
    """
    res = 1000  # must match the *_1000 assets loaded below

    iso_files = sorted(glob(os.path.join(unwrap_dir, '*_unwrap.jpg')))
    segm_files = sorted(glob(os.path.join(unwrap_dir, '*_segm.png')))
    vis_files = sorted(glob(os.path.join(unwrap_dir, '*_visibility.jpg')))
    if not iso_files:
        raise ValueError('no *_unwrap.jpg files found in {}'.format(unwrap_dir))
    if not len(iso_files) == len(segm_files) == len(vis_files):
        raise ValueError(
            'mismatched inputs in {}: {} unwrap, {} segm, {} visibility files'.format(
                unwrap_dir, len(iso_files), len(segm_files), len(vis_files)))

    iso_mask = _imread_checked('assets/tex_mask_1000.png', cv2.IMREAD_GRAYSCALE) / 255.
    iso_mask = cv2.resize(iso_mask, (res, res), interpolation=cv2.INTER_NEAREST)

    # The label index i (enumeration order of LABELS_REDUCED) is the last axis.
    voting = np.zeros((res, res, len(LABELS_REDUCED)))
    gmms = {color_id: GaussianMixture(LABELS_MIXTURES[color_id])
            for color_id in LABELS_REDUCED}
    # Per label: a list of (k, channels) pixel arrays, one per frame. Kept as
    # arrays rather than .tolist() to avoid huge lists of Python floats.
    gmm_pixels = {color_id: [] for color_id in LABELS_REDUCED}

    for frame_file, segm_file, vis_file in zip(iso_files, segm_files, vis_files):
        print('extract from {}...'.format(os.path.basename(frame_file)))
        frame = cv2.cvtColor(_imread_checked(frame_file), cv2.COLOR_BGR2HSV) / 255.
        tex_segm = read_segmentation(segm_file)
        tex_weights = np.sqrt(1 - _imread_checked(vis_file) / 255.)

        for i, color_id in enumerate(LABELS_REDUCED):
            if color_id in ('Unseen', 'BG'):
                continue
            where = np.all(tex_segm == LABELS_REDUCED[color_id], axis=2)
            voting[where, i] += tex_weights[where, 0]
            gmm_pixels[color_id].append(frame[where])

    for color_id in LABELS_REDUCED:
        chunks = [chunk for chunk in gmm_pixels[color_id] if len(chunk)]
        if chunks:
            print('GMM fit {}...'.format(color_id))
            gmms[color_id].fit(np.concatenate(chunks))

    # Penalise 'Unseen'/'BG' on every pixel. The index must be on the label
    # axis: ``voting[:, i]`` would hit image column i instead.
    for i, color_id in enumerate(LABELS_REDUCED):
        if color_id in ('Unseen', 'BG'):
            voting[:, :, i] = -10
    # Outside the texture mask only label 0 receives a vote.
    voting[iso_mask == 0] = 0
    voting[iso_mask == 0, 0] = 1

    unaries = np.ascontiguousarray((1 - voting / len(iso_files)) * 10)
    pairwise = np.ascontiguousarray(LABEL_COMP)

    seams = np.load('assets/basicModel_seams.npy')
    # pickle must only be used on trusted files; this one is a local asset.
    with open('assets/basicModel_edge_idx_1000.pkl', 'rb') as f:
        edge_idx = pkl.load(f)

    v_edges_from, v_edges_to, h_edges_from, h_edges_to = _grid_edges(iso_mask)
    s_edges_from, s_edges_to = edges_seams(seams, res, edge_idx)
    edges_from = np.r_[v_edges_from, h_edges_from, s_edges_from]
    edges_to = np.r_[v_edges_to, h_edges_to, s_edges_to]
    edges_w = np.r_[np.ones_like(v_edges_from),
                    np.ones_like(h_edges_from),
                    np.ones_like(s_edges_from)]

    n_labels = pairwise.shape[0]
    graph = gco.gco()
    graph.createGeneralGraph(res ** 2, n_labels, True)
    try:
        graph.set_data_cost(unaries.reshape(res ** 2, n_labels))
        graph.set_all_neighbors(edges_from, edges_to, edges_w)
        graph.set_smooth_cost(pairwise)
        graph.swap(-1)
        labels = graph.get_labels().reshape(res, res)
    finally:
        # Free the native graph even if the optimisation fails.
        graph.destroy_graph()

    segm_colors = np.zeros((res, res, 3), dtype=np.uint8)
    for i, color_id in enumerate(LABELS_REDUCED):
        segm_colors[labels == i] = LABELS_REDUCED[color_id]
    # Channels are reversed before writing (OpenCV writes BGR order).
    if not cv2.imwrite(str(segm_out_file), segm_colors[:, :, ::-1]):
        raise IOError('could not write image: {}'.format(segm_out_file))

    with open(gmm_out_file, 'wb') as f:
        pkl.dump(gmms, f)