Example #1
0
def pred(image_np: np.ndarray, trimap_np: np.ndarray,
         model) -> 'tuple[np.ndarray, np.ndarray, np.ndarray]':
    ''' Predict alpha, foreground and background.
        Parameters:
        image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        trimap_np -- two channel trimap, first background then foreground. Dimensions: (h, w, 2)
        model -- network called as
                 model(image, trimap, image_transformed, trimap_transformed);
                 channel 0 of its first output is read as alpha, channels 1:4 as
                 the foreground and channels 4:7 as the background.
        Returns:
        fg: foreground image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        bg: background image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        alpha: alpha matte image between 0 and 1. Dimensions: (h, w)
    '''
    # NOTE(review): another function named `pred` is defined later in this
    # module and shadows this one at import time.
    h, w = trimap_np.shape[:2]

    # NOTE(review): if scale_input resamples, LANCZOS4 on a binary trimap can
    # ring outside {0, 1}; confirm this interpolation is intended.
    image_scale_np = scale_input(image_np, 1.0, cv2.INTER_LANCZOS4)
    trimap_scale_np = scale_input(trimap_np, 1.0, cv2.INTER_LANCZOS4)

    with torch.no_grad():

        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)

        trimap_transformed_torch = np_to_torch(trimap_transform(trimap_scale_np))
        # clone() keeps image_torch (also passed to the model) untouched,
        # presumably because the normaliser may work in place.
        image_transformed_torch = groupnorm_normalise_image(image_torch.clone(), format='nchw')

        output = model(image_torch, trimap_torch, image_transformed_torch, trimap_transformed_torch)

        # Back to the original (h, w) resolution. The interpolation flag must
        # be passed by keyword: cv2.resize's third positional parameter is
        # `dst`, so a positional flag is not used as the interpolation mode.
        output = cv2.resize(output[0].cpu().numpy().transpose((1, 2, 0)), (w, h),
                            interpolation=cv2.INTER_LANCZOS4)

    # These are views into `output`, so the in-place edits below apply to it too.
    alpha = output[:, :, 0]
    fg = output[:, :, 1:4]
    bg = output[:, :, 4:7]

    # Definite trimap regions override the network's prediction.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1

    # Fully opaque pixels show the true foreground colour, and fully
    # transparent pixels show the true background colour.
    opaque = alpha == 1
    transparent = alpha == 0
    fg[opaque] = image_np[opaque]
    bg[transparent] = image_np[transparent]
    return fg, bg, alpha
def pred(image_np: np.ndarray, trimap_np: np.ndarray, alpha_old_np: np.ndarray,
         model) -> np.ndarray:
    ''' Predict segmentation
        Parameters:
        image_np -- the image in rgb format between 0 and 1. Dimensions: (h, w, 3)
        trimap_np -- two channel trimap/Click map, first background then foreground. Dimensions: (h, w, 2)
        alpha_old_np -- single-channel alpha fed back into the model (presumably
                        the previous prediction). A channel axis is added to it
                        below, so it is expected to be 2-D, (h, w).
        model -- network called as
                 model(image_transformed, trimap_transformed, alpha_old, trimap)
        Returns:
        alpha: alpha matte/non-binary segmentation image between 0 and 1. Dimensions: (h, w)
    '''
    # remove_non_fg_connected is defined elsewhere; presumably it discards
    # regions not connected to the foreground channel of the trimap.
    alpha_old_np = remove_non_fg_connected(alpha_old_np, trimap_np[:, :, 1])

    h, w = trimap_np.shape[:2]
    image_scale_np = scale_input(image_np, cv2.INTER_LANCZOS4)
    # Nearest-neighbour keeps the trimap / click map values unchanged.
    trimap_scale_np = scale_input(trimap_np, cv2.INTER_NEAREST)
    alpha_old_scale_np = scale_input(alpha_old_np, cv2.INTER_LANCZOS4)

    with torch.no_grad():

        image_torch = np_to_torch(image_scale_np)
        trimap_torch = np_to_torch(trimap_scale_np)
        # Give the previous alpha an explicit channel axis before conversion.
        alpha_old_torch = np_to_torch(alpha_old_scale_np[:, :, None])

        trimap_transformed_torch = np_to_torch(
            trimap_transform(trimap_scale_np))
        image_transformed_torch = groupnorm_normalise_image(
            image_torch.clone(), format='nchw')

        alpha = model(image_transformed_torch, trimap_transformed_torch,
                      alpha_old_torch, trimap_torch)
        # Back to the original (h, w) resolution. The interpolation flag must
        # be passed by keyword: cv2.resize's third positional parameter is
        # `dst`, so a positional flag is not used as the interpolation mode.
        # Assumes a single-channel model output, for which OpenCV returns (h, w).
        # NOTE(review): LANCZOS4 can overshoot slightly outside [0, 1].
        alpha = cv2.resize(alpha[0].cpu().numpy().transpose((1, 2, 0)), (w, h),
                           interpolation=cv2.INTER_LANCZOS4)

    # Definite background / foreground from the trimap override the prediction.
    alpha[trimap_np[:, :, 0] == 1] = 0
    alpha[trimap_np[:, :, 1] == 1] = 1

    alpha = remove_non_fg_connected(alpha, trimap_np[:, :, 1])
    return alpha