def test_generate_trimap(self):
        """Check generate_trimap on a fixture and dump a random crop to temp/.

        Verifies that the trimap has the fixture's (615, 410) shape, that
        ``np.where`` finds exactly as many ``unknown_code`` pixels as a
        brute-force scan, and that a randomly picked one of those indices
        really holds ``unknown_code``. The cropped image, trimap and alpha
        are then written under ``temp/`` for visual inspection.
        """
        image = cv.imread('fg/1-1252426161dfXY.jpg')
        alpha = cv.imread('mask/1-1252426161dfXY.jpg', 0)
        # cv.imread returns None instead of raising when a fixture is missing;
        # fail here with a clear message rather than deep inside a helper.
        self.assertIsNotNone(image)
        self.assertIsNotNone(alpha)
        trimap = generate_trimap(alpha)
        self.assertEqual(trimap.shape, (615, 410))

        # Cross-check np.where against a brute-force pixel scan.
        count = 0
        h, w = trimap.shape[:2]
        for i in range(h):
            for j in range(w):
                if trimap[i, j] == unknown_code:
                    count += 1
        row_indices, col_indices = np.where(trimap == unknown_code)
        num_unknowns = len(row_indices)
        self.assertEqual(count, num_unknowns)

        # Picking a random unknown pixel needs at least one to exist;
        # otherwise the draw below would raise a bare IndexError.
        self.assertGreater(num_unknowns, 0)
        ix = random.randrange(num_unknowns)
        center_row = row_indices[ix]
        center_col = col_indices[ix]
        self.assertEqual(trimap[center_row, center_col], unknown_code)

        x, y = random_choice(trimap)
        image = safe_crop(image, x, y)
        trimap = safe_crop(trimap, x, y)
        alpha = safe_crop(alpha, x, y)
        cv.imwrite('temp/test_generate_trimap_image.png', image)
        cv.imwrite('temp/test_generate_trimap_trimap.png', trimap)
        cv.imwrite('temp/test_generate_trimap_alpha.png', alpha)
 def test_flip(self):
     """Mirror a random crop of the fixture and save it for inspection.

     Builds a trimap from the fixture mask, crops image, trimap and alpha
     around the position returned by ``random_choice``, flips all three
     left-to-right and writes them under ``temp/``.
     """
     fg = cv.imread('fg/1-1252426161dfXY.jpg')
     mask = cv.imread('mask/1-1252426161dfXY.jpg', 0)
     tri = generate_trimap(mask)
     cx, cy = random_choice(tri)

     # Keep the phases (crop all, flip all, write all) in a fixed order.
     cropped = [safe_crop(arr, cx, cy) for arr in (fg, tri, mask)]
     mirrored = [np.fliplr(arr) for arr in cropped]
     targets = (
         'temp/test_flip_image.png',
         'temp/test_flip_trimap.png',
         'temp/test_flip_alpha.png',
     )
     for path, arr in zip(targets, mirrored):
         cv.imwrite(path, arr)
 def test_resize(self):
     """Crop a merged sample at 480x480 and save the pieces for inspection.

     The ground-truth alpha of ``0_0.png`` is placed in the top-left corner
     of a zero-filled float32 canvas the size of the merged image. A trimap
     is generated from that canvas, and image, trimap and alpha are cropped
     at the position chosen by ``random_choice``. Only the 480x480 crop size
     is exercised here. Results are written under ``temp/``.
     """
     sample = '0_0.png'
     merged = cv.imread(os.path.join('merged', sample))
     height, width = merged.shape[:2]

     fg_alpha = get_alpha(sample)
     alpha_h, alpha_w = fg_alpha.shape[:2]
     canvas = np.zeros((height, width), np.float32)
     canvas[0:alpha_h, 0:alpha_w] = fg_alpha

     tri = generate_trimap(canvas)
     size = (480, 480)
     cx, cy = random_choice(tri, size)

     outputs = [
         ('temp/test_resize_image.png', safe_crop(merged, cx, cy, size)),
         ('temp/test_resize_trimap.png', safe_crop(tri, cx, cy, size)),
         ('temp/test_resize_alpha.png', safe_crop(canvas, cx, cy, size)),
     ]
     for path, patch in outputs:
         cv.imwrite(path, patch)
Esempio n. 4
0
def matte(image_path, trimap_path, model):
    """Predict an alpha matte for a random crop of an image and save it.

    Reads the image at ``image_path`` with ``cv.imread`` and the trimap at
    ``trimap_path`` as grayscale. Picks a crop size: 320x320 with
    probability 3/5, or 480x480 / 640x640 with probability 1/5 each.
    Crops both inputs around the position returned by ``random_choice`` and
    runs ``model`` on the 4-channel input scaled by 1/255. Post-processes
    the result with ``get_final_output``, then writes
    ``matting/image.png``, ``matting/trimap.png`` and ``matting/out.png``.

    Args:
        image_path: path of the input image.
        trimap_path: path of the trimap, read as single-channel grayscale.
        model: object with a ``predict`` method whose output can be reshaped
            to ``(img_rows, img_cols)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read either input file.
    """
    # Read the background image. cv.imread returns None (it does not raise)
    # for a missing or unreadable file, so check explicitly.
    bgr_img = cv.imread(image_path)
    if bgr_img is None:
        raise FileNotFoundError('Cannot read image: {}'.format(image_path))
    bg_h, bg_w = bgr_img.shape[:2]
    print('bg_h, bg_w: ' + str((bg_h, bg_w)))

    # Read the trimap in grayscale.
    trimap = cv.imread(trimap_path, 0)
    if trimap is None:
        raise FileNotFoundError('Cannot read trimap: {}'.format(trimap_path))

    # Crop. Listing 320x320 three times makes it 3x as likely as 480 or 640.
    different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
    crop_size = random.choice(different_sizes)
    x, y = random_choice(trimap, crop_size)
    print('x, y: ' + str((x, y)))

    # NOTE(review): the x_test assignments below assume safe_crop returns
    # arrays of size (img_rows, img_cols) whatever crop_size is; confirm.
    bgr_img = safe_crop(bgr_img, x, y, crop_size)
    trimap = safe_crop(trimap, x, y, crop_size)
    cv.imwrite('matting/image.png', np.array(bgr_img).astype(np.uint8))
    cv.imwrite('matting/trimap.png', np.array(trimap).astype(np.uint8))

    # Network input: image / 255 in channels 0-2, trimap / 255 in channel 3.
    x_test = np.empty((1, img_rows, img_cols, 4), dtype=np.float32)
    x_test[0, :, :, 0:3] = bgr_img / 255.
    x_test[0, :, :, 3] = trimap / 255.

    y_pred = model.predict(x_test)
    y_pred = np.reshape(y_pred, (img_rows, img_cols))
    print(y_pred.shape)
    # Rescale before post-processing (the prediction is presumably in [0, 1]).
    y_pred = y_pred * 255.0
    y_pred = get_final_output(y_pred, trimap)
    y_pred = y_pred.astype(np.uint8)

    cv.imwrite('matting/out.png', y_pred)
Esempio n. 5
0
        bg_h, bg_w = bgr_img.shape[:2]
        print('bg_h, bg_w: ' + str((bg_h, bg_w)))

        a = get_alpha_test(image_name)
        a_h, a_w = a.shape[:2]
        print('a_h, a_w: ' + str((a_h, a_w)))

        alpha = np.zeros((bg_h, bg_w), np.float32)
        alpha[0:a_h, 0:a_w] = a
        trimap = generate_trimap(alpha)
        different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480), (640, 640)]
        crop_size = random.choice(different_sizes)
        x, y = random_choice(trimap, crop_size)
        print('x, y: ' + str((x, y)))

        bgr_img = safe_crop(bgr_img, x, y, crop_size)
        alpha = safe_crop(alpha, x, y, crop_size)
        trimap = safe_crop(trimap, x, y, crop_size)
        cv.imwrite('images/{}_image.png'.format(i), np.array(bgr_img).astype(np.uint8))
        cv.imwrite('images/{}_trimap.png'.format(i), np.array(trimap).astype(np.uint8))
        cv.imwrite('images/{}_alpha.png'.format(i), np.array(alpha).astype(np.uint8))

        x_test = np.empty((1, img_rows, img_cols, 4), dtype=np.float32)
        x_test[0, :, :, 0:3] = bgr_img / 255.
        x_test[0, :, :, 3] = trimap / 255.

        y_true = np.empty((1, img_rows, img_cols, 2), dtype=np.float32)
        y_true[0, :, :, 0] = alpha / 255.
        y_true[0, :, :, 1] = trimap / 255.

        y_pred = final.predict(x_test)
    def __getitem__(self, idx):
        """Assemble training batch number ``idx``.

        Each name in this batch's slice of ``self.names`` is parsed as
        ``'<fg index>_<bg index>.<ext>'``. The two indices select entries of
        ``fg_files`` and ``bg_files``. The foreground, alpha and background
        files are read through ``mio``, bulk-cached first when needed. They
        are combined by ``process``, cropped, randomly mirrored
        left-to-right and divided by 255.

        Args:
            idx: zero-based batch index.

        Returns:
            ``(batch_x, batch_y)``:

            - ``batch_x``: float32, shape ``(n, img_rows, img_cols, channel)``.
              Holds image / 255, plus trimap / 255 in channel 3 when
              ``channel == 4``.
            - ``batch_y``: float32, shape ``(n, img_rows, img_cols, 2)``.
              Holds alpha / 255 in channel 0 and the loss mask in channel 1.

            ``n`` is the number of samples that loaded successfully and may
            be 0.
        """
        # https://github.com/keras-team/keras/issues/3675#issuecomment-347697970
        if idx % 35 == 0:  # approx once per minute during first epoch on standard_p100
            t0 = time()
            gc.collect()
            t1 = time()
            print("Garbage collected batch %s in %.2fs" % (idx, t1 - t0))

        start = idx * self.batch_size

        length = min(self.batch_size, (len(self.names) - start))
        batch_x = np.empty((length, img_rows, img_cols, channel),
                           dtype=np.float32)
        batch_y = np.empty((length, img_rows, img_cols, 2), dtype=np.float32)

        bad_images = []

        # 1. Build the (fg, alpha, bg) paths for every sample in the batch.
        #    Each sample uses its own name, self.names[start + i_batch].
        paths = []
        for i_batch in range(length):
            name = self.names[start + i_batch]
            counts = name.split('.')[0].split('_')
            fcount = int(counts[0])
            bcount = int(counts[1])
            im_name = fg_files[fcount]
            bg_name = bg_files[bcount]
            paths.append((fg_base_path + im_name, a_base_path + im_name,
                          bg_base_path + bg_name))

        fg_cache_dir = os.path.join(cache_dir, 'fg')
        a_cache_dir = os.path.join(cache_dir, 'a')
        bg_cache_dir = os.path.join(cache_dir, 'bg')

        # 2. Bulk-cache the batch unless every foreground is already cached.
        #    Only the foreground cache is probed.
        is_cached = all(mio.is_cached(fg_path, fg_cache_dir)
                        for fg_path, _, _ in paths)

        if not is_cached:
            paths_by_dir = defaultdict(list)
            paths_by_dir[fg_cache_dir].extend([p[0] for p in paths])
            paths_by_dir[a_cache_dir].extend([p[1] for p in paths])
            paths_by_dir[bg_cache_dir].extend([p[2] for p in paths])

            # Retry failures with a 1 s pause; give up after 5 attempts.
            retry = 0
            while True:
                try:
                    mio.batch_cache(paths_by_dir)
                    break
                except Exception:
                    retry += 1
                    if retry >= 5:
                        raise
                    sleep(1)

        # 3. Load, augment and normalise each sample.
        for i_batch in range(length):
            fg_path, a_path, bg_path = paths[i_batch]
            fg = mio.imread(fg_path, cache_dir=fg_cache_dir)
            a = mio.imread(a_path, flags=0, cache_dir=a_cache_dir)
            bg = mio.imread(bg_path, cache_dir=bg_cache_dir)
            if fg is None or a is None or bg is None:
                if fg is None:
                    bad = fg_path
                elif a is None:
                    bad = a_path
                else:
                    bad = bg_path
                print("Bad image: %s" % bad)
                bad_images.append(i_batch)
                print("Skipping bad image")
                continue

            image, alpha, fg, bg = process(fg, a, bg)

            trimap = generate_trimap(alpha)

            if not skip_crop:
                # crop size 320:480:640 = 1:1:1
                different_sizes = [(320, 320), (480, 480), (640, 640)]
                crop_size = random.choice(different_sizes)

                x, y = random_choice(trimap, crop_size)
                image = safe_crop(image, x, y, crop_size)
                alpha = safe_crop(alpha, x, y, crop_size)
            else:
                # Centred crop of size (img_rows, img_cols).
                h, w = image.shape[:2]
                x = 0 if img_cols == w else (w - img_cols) // 2
                y = 0 if img_rows == h else (h - img_rows) // 2
                image = crop(image, x, y, (img_rows, img_cols))
                alpha = crop(alpha, x, y, (img_rows, img_cols))

            if channel == 4:
                # Regenerate the trimap from the cropped alpha so they line up.
                trimap = generate_trimap(alpha)

            # Flip left to right with probability ~1/2.
            if np.random.random_sample() > 0.5:
                image = np.fliplr(image)
                alpha = np.fliplr(alpha)
                if channel == 4:
                    trimap = np.fliplr(trimap)

            batch_x[i_batch, :, :, 0:3] = image / 255.
            if channel == 4:
                batch_x[i_batch, :, :, 3] = trimap / 255.
                # Loss mask: trimap pixels equal to 128
                # (presumably the unknown region).
                mask = np.equal(trimap, 128).astype(np.float32)
            else:
                mask = np.ones((img_rows, img_cols))

            batch_y[i_batch, :, :, 0] = alpha / 255.
            batch_y[i_batch, :, :, 1] = mask

        if bad_images:
            if len(bad_images) == length:
                print("WARNING: Empty batch!")
            # Rows of skipped samples were never written (np.empty), so always
            # drop them, even for an all-bad batch. Otherwise uninitialised
            # memory would be returned as training data.
            batch_x = np.delete(batch_x, bad_images, 0)
            batch_y = np.delete(batch_y, bad_images, 0)

        return batch_x, batch_y