def pred(image_np: np.ndarray, trimap_np: np.ndarray, model) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
    '''Predict alpha, foreground and background with a matting network.

    Parameters:
        image_np -- the image in rgb format between 0 and 1.
                    Dimensions: (h, w, 3)
        trimap_np -- two channel trimap, first background then foreground.
                     Dimensions: (h, w, 2)
        model -- network called as model(image, trimap, normalised_image,
                 transformed_trimap); expected to emit a 7-channel map:
                 channel 0 = alpha, 1:4 = fg rgb, 4:7 = bg rgb.
                 (inferred from the slicing below -- confirm against model)

    Returns:
        fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
    '''
    h, w = trimap_np.shape[:2]

    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)
        trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
        image_transformed_torch = groupnorm_normalise_image(
            image_torch.clone(), format='nchw')

        output = model(image_torch, trimap_torch,
                       image_transformed_torch, trimap_transformed_torch)

    # BUG FIX: cv2.resize's third positional parameter is `dst`, not the
    # interpolation flag, so the original call silently fell back to the
    # default (bilinear). Pass the flag by keyword.
    output = cv2.resize(
        output[0].cpu().numpy().transpose((1, 2, 0)), (w, h),
        interpolation=cv2.INTER_LANCZOS4)

    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    # Clamp the prediction to the hard trimap constraints.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    # Where alpha is fully opaque/transparent the fg/bg is the input pixel.
    fg[alpha == 1] = image_np[alpha == 1]
    bg[alpha == 0] = image_np[alpha == 0]
    return fg, bg, alpha
def pred(image_np: np.ndarray, trimap_np: np.ndarray, alpha_old_np: np.ndarray, model) -> np.ndarray:
    '''Predict a refined segmentation from the previous alpha and user clicks.

    NOTE(review): this redefines the module-level name `pred` declared
    earlier in the file; only this definition is visible to importers.

    Parameters:
        image_np -- the image in rgb format between 0 and 1.
                    Dimensions: (h, w, 3)
        trimap_np -- two channel trimap/click map, first background then
                     foreground. Dimensions: (h, w, 2)
        alpha_old_np -- previous alpha estimate, fed back into the model.
                        Dimensions: (h, w)
        model -- network called as model(normalised_image, transformed_trimap,
                 previous_alpha, trimap); returns the new alpha map.

    Returns:
        alpha: alpha matte/non-binary segmentation image between 0 and 1.
               Dimensions: (h, w)
    '''
    # Keep only foreground regions connected to a foreground click before
    # feeding the previous alpha back into the network.
    alpha_old_np = remove_non_fg_connected(alpha_old_np, trimap_np[:, :, 1])

    h, w = trimap_np.shape[:2]
    image_scale_np = scale_input(image_np, cv2.INTER_LANCZOS4)
    # Nearest-neighbour for the trimap: interpolation must not blur clicks.
    trimap_scale_np = scale_input(trimap_np, cv2.INTER_NEAREST)
    alpha_old_scale_np = scale_input(alpha_old_np, cv2.INTER_LANCZOS4)

    # Inference only: no autograd bookkeeping needed.
    with torch.no_grad():
        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)
        alpha_old_torch = np_to_torch(alpha_old_scale_np[:, :, None])
        trimap_transformed_torch = np_to_torch(
            trimap_transform(trimap_scale_np))
        image_transformed_torch = groupnorm_normalise_image(
            image_torch.clone(), format='nchw')

        alpha = model(image_transformed_torch, trimap_transformed_torch,
                      alpha_old_torch, trimap_torch)

    # BUG FIX: cv2.resize's third positional parameter is `dst`, not the
    # interpolation flag, so the original call silently fell back to the
    # default (bilinear). Pass the flag by keyword.
    alpha = cv2.resize(
        alpha[0].cpu().numpy().transpose((1, 2, 0)), (w, h),
        interpolation=cv2.INTER_LANCZOS4)

    # Clamp the prediction to the hard click constraints.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1
    # Drop any foreground not connected to a foreground click.
    alpha = remove_non_fg_connected(alpha, trimap_np[:, :, 1])
    return alpha