class GrabCut:
    """GrabCut foreground extraction (Rother, Kolmogorov and Blake, 2004).

    The constructor seeds ``mask`` from ``rect``, computes the smoothness
    weights, fits the initial background/foreground GMMs and immediately runs
    one iteration of the Figure 3 loop. ``mask`` is stored without copying and
    is rewritten in place by estimate_segmentation(), so the caller's array
    holds the latest labels.

    Mask values are compared against the ``['val']`` entries of DRAW_BG,
    DRAW_FG, DRAW_PR_BG and DRAW_PR_FG.

    NOTE(review): ``class GrabCut`` is defined a second time later in this
    file; after import, that later definition is the one bound to the name.
    """

    # Lower bound applied to GMM likelihoods before taking -log, so every
    # t-link capacity stays finite even when a pixel's likelihood underflows
    # to 0.
    _MIN_PROB = np.finfo(np.float64).tiny

    def __init__(self, img, mask, rect=None, gmm_components=5):
        """Initialise the model and run one segmentation iteration.

        Args:
            img: array-like image; after conversion to float64 it must be
                3-D (rows, cols, channels).
            mask: 2-D array of DRAW_* label values with the same rows/cols as
                ``img``; modified in place.
            rect: optional ``(x, y, width, height)``; every pixel inside it is
                labelled DRAW_PR_FG before the pixels are classified.
            gmm_components: K, stored on the instance and used by
                calc_energy(); it is not passed to GaussianMixture here.
        """
        self.img = np.asarray(img, dtype=np.float64)
        # Take the shape from the converted array so array-like input works.
        self.rows, self.cols, _ = self.img.shape
        self.mask = mask
        if rect is not None:
            x, y, width, height = rect[0], rect[1], rect[2], rect[3]
            self.mask[y:y + height, x:x + width] = DRAW_PR_FG['val']
        self.classify_pixels()

        # Best number of GMM components K suggested in paper
        self.gmm_components = gmm_components
        self.gamma = 50  # Best gamma suggested in paper formula (5)
        self.beta = 0

        # Smoothness (n-link) weights, filled in by calc_beta_smoothness().
        self.left_V = None
        self.upleft_V = None
        self.up_V = None
        self.upright_V = None

        self.bgd_gmm = None
        self.fgd_gmm = None
        self.comp_idxs = np.empty((self.rows, self.cols), dtype=np.uint32)

        self.gc_graph = None
        self.gc_graph_capacity = None  # Edge capacities
        self.gc_source = self.cols * self.rows  # "object" terminal S
        self.gc_sink = self.gc_source + 1  # "background" terminal T

        self.calc_beta_smoothness()
        self.init_GMMs()
        self.run()

    def calc_beta_smoothness(self):
        """Compute beta and the smoothness term V described in formula (11).

        beta = 1 / (2 * mean ||z_m - z_n||^2) over all left, up-left, up and
        up-right neighbour pairs. Each V array holds
        gamma * exp(-beta * ||z_m - z_n||^2) for one direction; the diagonal
        directions are additionally divided by sqrt(2).
        """
        img = self.img
        left_sq = np.square(img[:, 1:] - img[:, :-1])
        upleft_sq = np.square(img[1:, 1:] - img[:-1, :-1])
        up_sq = np.square(img[1:, :] - img[:-1, :])
        upright_sq = np.square(img[1:, :-1] - img[:-1, 1:])

        total = (np.sum(left_sq) + np.sum(upleft_sq) +
                 np.sum(up_sq) + np.sum(upright_sq))
        # Number of neighbour pairs: 4 per pixel, minus the left/up-left links
        # missing in the first column and the up-right link missing in the
        # last column (3 per row), minus the up-left/up/up-right links missing
        # in the first row (3 per column), plus 2 for the two corner links
        # that were subtracted twice.
        num_pairs = 4 * self.cols * self.rows - 3 * self.cols - 3 * self.rows + 2
        if total == 0:
            # Uniform image (or a single pixel): every distance is 0, so
            # V == gamma for any finite beta. 1/0 would make beta inf and
            # -inf * 0 would turn every weight into NaN.
            self.beta = 0.0
        else:
            self.beta = 1 / (2 * total / num_pairs)
        print('Beta:', self.beta)

        # Smoothness term V described in formula (11)
        self.left_V = self.gamma * np.exp(
            -self.beta * np.sum(left_sq, axis=2))
        self.upleft_V = self.gamma / np.sqrt(2) * np.exp(
            -self.beta * np.sum(upleft_sq, axis=2))
        self.up_V = self.gamma * np.exp(
            -self.beta * np.sum(up_sq, axis=2))
        self.upright_V = self.gamma / np.sqrt(2) * np.exp(
            -self.beta * np.sum(upright_sq, axis=2))

    def classify_pixels(self):
        """Split pixels into (probable) background and (probable) foreground.

        Sets ``bgd_indexes`` / ``fgd_indexes`` to ``np.where`` index tuples
        into the 2-D mask. Both sets are asserted non-empty; the GMMs are
        fitted on them.
        """
        self.bgd_indexes = np.where(np.logical_or(
            self.mask == DRAW_BG['val'], self.mask == DRAW_PR_BG['val']))
        self.fgd_indexes = np.where(np.logical_or(
            self.mask == DRAW_FG['val'], self.mask == DRAW_PR_FG['val']))

        assert self.bgd_indexes[0].size > 0
        assert self.fgd_indexes[0].size > 0

        print('(pr_)bgd count: %d, (pr_)fgd count: %d' % (
            self.bgd_indexes[0].size, self.fgd_indexes[0].size))

    def init_GMMs(self):
        """Create the background and foreground GMMs from the current labels."""
        self.bgd_gmm = GaussianMixture(self.img[self.bgd_indexes])
        self.fgd_gmm = GaussianMixture(self.img[self.fgd_indexes])

    def assign_GMMs_components(self):
        """Step 1 in Figure 3: Assign GMM components to pixels"""
        self.comp_idxs[self.bgd_indexes] = self.bgd_gmm.which_component(
            self.img[self.bgd_indexes])
        self.comp_idxs[self.fgd_indexes] = self.fgd_gmm.which_component(
            self.img[self.fgd_indexes])

    def learn_GMMs(self):
        """Step 2 in Figure 3: Learn GMM parameters from data z"""
        self.bgd_gmm.fit(self.img[self.bgd_indexes],
                         self.comp_idxs[self.bgd_indexes])
        self.fgd_gmm.fit(self.img[self.fgd_indexes],
                         self.comp_idxs[self.fgd_indexes])

    def _neg_log_likelihood(self, gmm, pixels):
        """Return -log(gmm.calc_prob(pixels)) with likelihoods floored at _MIN_PROB."""
        prob = np.asarray(gmm.calc_prob(pixels), dtype=np.float64)
        return -np.log(np.maximum(prob, self._MIN_PROB))

    def construct_gc_graph(self):
        """Build the s-t graph and edge capacities for the min cut.

        t-links: uncertain (DRAW_PR_*) pixels get -log p(z | bgd GMM) to the
        source and -log p(z | fgd GMM) to the sink. User-fixed pixels get 0 on
        one terminal and 9 * gamma on the other. 9 * gamma exceeds the largest
        possible total n-link weight of a pixel (4*gamma + 4*gamma/sqrt(2)),
        so the cut never relabels those pixels.
        n-links: the smoothness weights from calc_beta_smoothness().
        """
        flat_mask = self.mask.reshape(-1)
        bgd_indexes = np.where(flat_mask == DRAW_BG['val'])
        fgd_indexes = np.where(flat_mask == DRAW_FG['val'])
        pr_indexes = np.where(np.logical_or(
            flat_mask == DRAW_PR_BG['val'], flat_mask == DRAW_PR_FG['val']))

        print('bgd count: %d, fgd count: %d, uncertain count: %d' % (
            len(bgd_indexes[0]), len(fgd_indexes[0]), len(pr_indexes[0])))

        edges = []
        self.gc_graph_capacity = []
        capacity = self.gc_graph_capacity

        def add_links(first_nodes, second_nodes, capacities):
            # Append edges and their capacities in lockstep.
            edges.extend(zip(first_nodes, second_nodes))
            capacity.extend(capacities)
            assert len(edges) == len(capacity)

        pr = pr_indexes[0]
        bgd = bgd_indexes[0]
        fgd = fgd_indexes[0]
        # One row per pixel, whatever the channel count.
        pr_pixels = self.img.reshape(self.rows * self.cols, -1)[pr]
        hard = 9 * self.gamma

        # t-links for uncertain pixels: data term D from each GMM.
        add_links([self.gc_source] * pr.size, pr,
                  self._neg_log_likelihood(self.bgd_gmm, pr_pixels).tolist())
        add_links([self.gc_sink] * pr.size, pr,
                  self._neg_log_likelihood(self.fgd_gmm, pr_pixels).tolist())
        # t-links for user-fixed pixels.
        add_links([self.gc_source] * bgd.size, bgd, [0] * bgd.size)
        add_links([self.gc_sink] * bgd.size, bgd, [hard] * bgd.size)
        add_links([self.gc_source] * fgd.size, fgd, [hard] * fgd.size)
        add_links([self.gc_sink] * fgd.size, fgd, [0] * fgd.size)

        # n-links: pixel i is node i (row-major).
        img_indexes = np.arange(self.rows * self.cols,
                                dtype=np.uint32).reshape(self.rows, self.cols)
        for nodes_a, nodes_b, weights in (
                (img_indexes[:, 1:], img_indexes[:, :-1], self.left_V),
                (img_indexes[1:, 1:], img_indexes[:-1, :-1], self.upleft_V),
                (img_indexes[1:, :], img_indexes[:-1, :], self.up_V),
                (img_indexes[1:, :-1], img_indexes[:-1, 1:], self.upright_V)):
            add_links(nodes_a.reshape(-1), nodes_b.reshape(-1),
                      weights.reshape(-1).tolist())

        assert len(edges) == 4 * self.cols * self.rows - 3 * (self.cols + self.rows) + 2 + \
            2 * self.cols * self.rows

        self.gc_graph = ig.Graph(self.cols * self.rows + 2)
        self.gc_graph.add_edges(edges)

    def estimate_segmentation(self):
        """Step 3 in Figure 3: Estimate segmentation"""
        mincut = self.gc_graph.st_mincut(
            self.gc_source, self.gc_sink, self.gc_graph_capacity)
        print('foreground pixels: %d, background pixels: %d' % (
            len(mincut.partition[0]), len(mincut.partition[1])))
        # Only uncertain pixels are relabelled: those whose node id lands in
        # partition[0] become DRAW_PR_FG, the rest DRAW_PR_BG.
        pr_indexes = np.where(np.logical_or(
            self.mask == DRAW_PR_BG['val'], self.mask == DRAW_PR_FG['val']))
        img_indexes = np.arange(self.rows * self.cols,
                                dtype=np.uint32).reshape(self.rows, self.cols)
        self.mask[pr_indexes] = np.where(
            np.isin(img_indexes[pr_indexes], mincut.partition[0]),
            DRAW_PR_FG['val'], DRAW_PR_BG['val'])
        self.classify_pixels()

    def calc_energy(self):
        """Return ``(U, V, U + V)`` for the current labelling.

        U sums -log(coefs[k] * calc_score(z, k)) over every pixel, using the
        pixel's component ``k`` from ``comp_idxs`` and the GMM of its current
        (probable) label. V is the smoothness term of formula (11). It is the
        sum of the n-link weights of all neighbour pairs whose labels differ.
        """
        bgd_mask = np.logical_or(
            self.mask == DRAW_BG['val'], self.mask == DRAW_PR_BG['val'])
        fgd_mask = np.logical_or(
            self.mask == DRAW_FG['val'], self.mask == DRAW_PR_FG['val'])

        U = 0
        for ci in range(self.gmm_components):
            in_comp = self.comp_idxs == ci
            idx = np.where(np.logical_and(in_comp, bgd_mask))
            U += np.sum(-np.log(self.bgd_gmm.coefs[ci] *
                                self.bgd_gmm.calc_score(self.img[idx], ci)))
            idx = np.where(np.logical_and(in_comp, fgd_mask))
            U += np.sum(-np.log(self.fgd_gmm.coefs[ci] *
                                self.fgd_gmm.calc_score(self.img[idx], ci)))

        # Collapse probable labels onto their definite counterparts so only
        # the bgd/fgd distinction matters when comparing neighbours.
        labels = self.mask.copy()
        labels[labels == DRAW_PR_BG['val']] = DRAW_BG['val']
        labels[labels == DRAW_PR_FG['val']] = DRAW_FG['val']

        # Formula (11) charges a pair only when the labels differ.
        V = 0
        V += np.sum(self.left_V * (labels[:, 1:] != labels[:, :-1]))
        V += np.sum(self.upleft_V * (labels[1:, 1:] != labels[:-1, :-1]))
        V += np.sum(self.up_V * (labels[1:, :] != labels[:-1, :]))
        V += np.sum(self.upright_V * (labels[1:, :-1] != labels[:-1, 1:]))
        return U, V, U + V

    def run(self, num_iters=1, skip_learn_GMMs=False):
        """Run ``num_iters`` iterations of the Figure 3 loop.

        If ``skip_learn_GMMs`` is true, the first iteration reuses the current
        GMMs and skips steps 1 and 2. Every later iteration re-learns them.
        """
        print('skip learn GMMs:', skip_learn_GMMs)
        for _ in range(num_iters):
            if not skip_learn_GMMs:
                self.assign_GMMs_components()
                self.learn_GMMs()
            self.construct_gc_graph()
            self.estimate_segmentation()
            skip_learn_GMMs = False
class GrabCut:
    """GrabCut foreground extraction (Rother, Kolmogorov and Blake, 2004).

    The constructor only prepares the model: it seeds the labels, computes the
    smoothness weights and fits the initial GMMs. Call run() to iterate.
    ``mask`` is stored without copying and is rewritten in place by
    estimate_segmentation().

    Mask values are compared against the ``['val']`` entries of DRAW_BG,
    DRAW_FG, DRAW_PR_BG and DRAW_PR_FG.

    NOTE(review): this is the second ``class GrabCut`` in the file. Being
    later, it is the one bound to the name after import, and unlike the
    earlier one it neither runs an iteration in __init__ nor has
    calc_energy().
    """

    # Lower bound applied to GMM likelihoods before taking -log, so every
    # t-link capacity stays finite even when a pixel's likelihood underflows
    # to 0.
    _MIN_PROB = np.finfo(np.float64).tiny

    def __init__(self, img, mask, bounding_box=None, gmm_components=5):
        """Prepare labels, smoothness weights and the initial GMMs.

        Args:
            img: array-like image; after conversion to float64 it must be
                (rows, cols, channels).
            mask: 2-D array of DRAW_* label values with the same rows/cols as
                ``img``; modified in place.
            bounding_box: optional ``(x, y, width, height)``; every pixel inside
                it is labelled DRAW_PR_FG before the pixels are classified.
            gmm_components: K, stored on the instance; not passed to
                GaussianMixture here.
        """
        self.img = np.asarray(img, dtype=np.float64)
        # Take the shape from the converted array so array-like input works.
        self.rows, self.cols = self.img.shape[0], self.img.shape[1]
        self.mask = mask
        if bounding_box is not None:
            x, y = bounding_box[0], bounding_box[1]
            width, height = bounding_box[2], bounding_box[3]
            self.mask[y:y + height, x:x + width] = DRAW_PR_FG['val']
        self.classify_pixels()

        # Best number of GMM components K suggested in paper
        self.gmm_components = gmm_components
        self.gamma = 50  # Best gamma suggested in paper formula (5)
        self.beta = 0

        # Smoothness (n-link) weights, filled in by _calc_beta_smoothness().
        self.dis_W = None
        self.dis_NW = None
        self.dis_N = None
        self.dis_NE = None

        self.bgd_gmm = None
        self.fgd_gmm = None
        self.comp_idxs = np.empty((self.rows, self.cols), dtype=np.uint32)

        self.gc_graph = None
        self.gc_graph_capacity = None  # Edge capacities
        self.gc_source = self.cols * self.rows  # "object" terminal S
        self.gc_sink = self.gc_source + 1  # "background" terminal T

        self._calc_beta_smoothness()

        # Apply GaussianMixture for both foreground and background
        self.bgd_gmm = GaussianMixture(self.img[self.bgd_indexes])
        self.fgd_gmm = GaussianMixture(self.img[self.fgd_indexes])

    def _calc_beta_smoothness(self):
        """Compute beta and the per-direction smoothness weights (formula (11)).

        beta = 1 / (2 * mean ||z_m - z_n||^2) over the W, NW, N and NE
        neighbour pairs (four directions cover every pair once). Each
        ``dis_*`` array holds gamma * exp(-beta * ||z_m - z_n||^2), and the
        diagonal directions are additionally divided by sqrt(2).
        """
        img = self.img
        west_sq = np.square(img[:, 1:] - img[:, :-1])
        northwest_sq = np.square(img[1:, 1:] - img[:-1, :-1])
        north_sq = np.square(img[1:, :] - img[:-1, :])
        northeast_sq = np.square(img[1:, :-1] - img[:-1, 1:])

        total = (np.sum(west_sq) + np.sum(northwest_sq) +
                 np.sum(north_sq) + np.sum(northeast_sq))
        # Number of neighbour pairs in the four directions above.
        num_pairs = 4 * self.cols * self.rows - 3 * self.cols - 3 * self.rows + 2
        if total == 0:
            # Uniform image (or a single pixel): every distance is 0, so the
            # weights equal gamma for any finite beta. 1/0 would make beta inf
            # and -inf * 0 would turn every weight into NaN.
            self.beta = 0.0
        else:
            self.beta = 1 / (2 * total / num_pairs)

        self.dis_W = self.gamma * np.exp(
            -self.beta * np.sum(west_sq, axis=2))
        self.dis_NW = self.gamma / np.sqrt(2) * np.exp(
            -self.beta * np.sum(northwest_sq, axis=2))
        self.dis_N = self.gamma * np.exp(
            -self.beta * np.sum(north_sq, axis=2))
        self.dis_NE = self.gamma / np.sqrt(2) * np.exp(
            -self.beta * np.sum(northeast_sq, axis=2))

    def classify_pixels(self):
        """Split pixels into (probable) background and (probable) foreground.

        Sets ``bgd_indexes`` / ``fgd_indexes`` to ``np.where`` index tuples
        into the 2-D mask. Both sets are asserted non-empty; the GMMs are
        fitted on them.
        """
        self.bgd_indexes = np.where(np.logical_or(
            self.mask == DRAW_BG['val'], self.mask == DRAW_PR_BG['val']))
        self.fgd_indexes = np.where(np.logical_or(
            self.mask == DRAW_FG['val'], self.mask == DRAW_PR_FG['val']))

        assert self.bgd_indexes[0].size > 0
        assert self.fgd_indexes[0].size > 0

    def assign_GMMs_components(self):
        """Step 1 in Figure 3: Assign GMM components to pixels"""
        self.comp_idxs[self.bgd_indexes] = self.bgd_gmm.which_component(
            self.img[self.bgd_indexes])
        self.comp_idxs[self.fgd_indexes] = self.fgd_gmm.which_component(
            self.img[self.fgd_indexes])

    def learn_GMMs(self):
        """Step 2 in Figure 3: Learn GMM parameters from data z"""
        self.bgd_gmm.fit(self.img[self.bgd_indexes],
                         self.comp_idxs[self.bgd_indexes])
        self.fgd_gmm.fit(self.img[self.fgd_indexes],
                         self.comp_idxs[self.fgd_indexes])

    def _neg_log_likelihood(self, gmm, pixels):
        """Return -log(gmm.calc_prob(pixels)) with likelihoods floored at _MIN_PROB."""
        prob = np.asarray(gmm.calc_prob(pixels), dtype=np.float64)
        return -np.log(np.maximum(prob, self._MIN_PROB))

    def construct_gc_graph(self):
        """Build the s-t graph and edge capacities for the min cut.

        t-links: uncertain (DRAW_PR_*) pixels get -log p(z | bgd GMM) to the
        source and -log p(z | fgd GMM) to the sink. User-fixed pixels get 0 on
        one terminal and 9 * gamma on the other. 9 * gamma exceeds the largest
        possible total n-link weight of a pixel (4*gamma + 4*gamma/sqrt(2)),
        so the cut never relabels those pixels.
        n-links: the W/NW/N/NE smoothness weights.
        """
        flat_mask = self.mask.reshape(-1)
        bgd_indexes = np.where(flat_mask == DRAW_BG['val'])
        fgd_indexes = np.where(flat_mask == DRAW_FG['val'])
        pr_indexes = np.where(np.logical_or(
            flat_mask == DRAW_PR_BG['val'], flat_mask == DRAW_PR_FG['val']))

        edges = []
        self.gc_graph_capacity = []
        capacity = self.gc_graph_capacity

        def add_links(first_nodes, second_nodes, capacities):
            # Append edges and their capacities in lockstep.
            edges.extend(zip(first_nodes, second_nodes))
            capacity.extend(capacities)
            assert len(edges) == len(capacity)

        pr = pr_indexes[0]
        bgd = bgd_indexes[0]
        fgd = fgd_indexes[0]
        # One row per pixel, whatever the channel count.
        pr_pixels = self.img.reshape(self.rows * self.cols, -1)[pr]
        hard = 9 * self.gamma

        # t-links for uncertain pixels: data term D from each GMM.
        add_links([self.gc_source] * pr.size, pr,
                  self._neg_log_likelihood(self.bgd_gmm, pr_pixels).tolist())
        add_links([self.gc_sink] * pr.size, pr,
                  self._neg_log_likelihood(self.fgd_gmm, pr_pixels).tolist())
        # t-links for user-fixed pixels.
        add_links([self.gc_source] * bgd.size, bgd, [0] * bgd.size)
        add_links([self.gc_sink] * bgd.size, bgd, [hard] * bgd.size)
        add_links([self.gc_source] * fgd.size, fgd, [hard] * fgd.size)
        add_links([self.gc_sink] * fgd.size, fgd, [0] * fgd.size)

        # n-links in the W, NW, N and NE directions; pixel i is node i.
        img_indexes = np.arange(self.rows * self.cols,
                                dtype=np.uint32).reshape(self.rows, self.cols)
        for nodes_a, nodes_b, weights in (
                (img_indexes[:, 1:], img_indexes[:, :-1], self.dis_W),
                (img_indexes[1:, 1:], img_indexes[:-1, :-1], self.dis_NW),
                (img_indexes[1:, :], img_indexes[:-1, :], self.dis_N),
                (img_indexes[1:, :-1], img_indexes[:-1, 1:], self.dis_NE)):
            add_links(nodes_a.reshape(-1), nodes_b.reshape(-1),
                      weights.reshape(-1).tolist())

        assert len(edges) == 4 * self.cols * self.rows - 3 * (self.cols + self.rows) + 2 + \
            2 * self.cols * self.rows

        self.gc_graph = ig.Graph(self.cols * self.rows + 2)
        self.gc_graph.add_edges(edges)

    def estimate_segmentation(self):
        """Step 3 in Figure 3: Estimate segmentation"""
        mincut = self.gc_graph.st_mincut(
            self.gc_source, self.gc_sink, self.gc_graph_capacity)
        print('foreground pixels: %d, background pixels: %d' % (
            len(mincut.partition[0]), len(mincut.partition[1])))
        # Only uncertain pixels are relabelled: those whose node id lands in
        # partition[0] become DRAW_PR_FG, the rest DRAW_PR_BG.
        pr_indexes = np.where(np.logical_or(
            self.mask == DRAW_PR_BG['val'], self.mask == DRAW_PR_FG['val']))
        img_indexes = np.arange(self.rows * self.cols,
                                dtype=np.uint32).reshape(self.rows, self.cols)
        self.mask[pr_indexes] = np.where(
            np.isin(img_indexes[pr_indexes], mincut.partition[0]),
            DRAW_PR_FG['val'], DRAW_PR_BG['val'])
        self.classify_pixels()

    def run(self, num_iters=1):
        """Run ``num_iters`` full iterations of the Figure 3 loop (steps 1-3)."""
        for _ in range(num_iters):
            self.assign_GMMs_components()
            self.learn_GMMs()
            self.construct_gc_graph()
            self.estimate_segmentation()