Beispiel #1
0
        x_test[0, :, :, 3] = trimap / 255.

        y_true = np.empty((1, img_rows, img_cols, 2), dtype=np.float32)
        y_true[0, :, :, 0] = alpha / 255.
        y_true[0, :, :, 1] = trimap / 255.

        y_pred = final.predict(x_test)
        # print('y_pred.shape: ' + str(y_pred.shape))

        y_pred = np.reshape(y_pred, (img_rows, img_cols))
        print(y_pred.shape)
        y_pred = y_pred * 255.0
        y_pred = get_final_output(y_pred, trimap)
        y_pred = y_pred.astype(np.uint8)

        sad_loss = compute_sad_loss(y_pred, alpha, trimap)
        mse_loss = compute_mse_loss(y_pred, alpha, trimap)
        str_msg = 'sad_loss: %.4f, mse_loss: %.4f, crop_size: %s' % (sad_loss, mse_loss, str(crop_size))
        print(str_msg)

        out = y_pred.copy()
        draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/{}_out.png'.format(i), out)

        sample_bg = sample_bgs[i]
        bg = cv.imread(os.path.join(bg_test, sample_bg))
        bh, bw = bg.shape[:2]
        wratio = img_cols / bw
        hratio = img_rows / bh
        ratio = wratio if wratio > hratio else hratio
        if ratio > 1:
Beispiel #2
0
        cv.imwrite('images/{}_trimap.png'.format(i), np.array(trimap).astype(np.uint8))
        cv.imwrite('images/{}_alpha.png'.format(i), np.array(alpha).astype(np.uint8))

        x_test = np.empty((1, 320, 320, 4), dtype=np.float32)
        x_test[0, :, :, 0:3] = bgr_img / 255.
        x_test[0, :, :, 3] = trimap / 255.

        y_true = np.empty((1, 320, 320, 2), dtype=np.float32)
        y_true[0, :, :, 0] = alpha / 255.
        y_true[0, :, :, 1] = trimap / 255.

        y_pred = final.predict(x_test)
        # print('y_pred.shape: ' + str(y_pred.shape))

        y_pred = np.reshape(y_pred, (img_rows, img_cols))
        print(y_pred.shape)
        y_pred = y_pred * 255.0
        y_pred = get_final_output(y_pred, trimap)
        y_pred = y_pred.astype(np.uint8)

        sad_loss = compute_sad_loss(y_pred, alpha, trimap)
        mse_loss = compute_mse_loss(y_pred, alpha, trimap)
        str_msg = 'sad_loss: %.4f, mse_loss: %.4f, crop_size: %s' % (sad_loss, mse_loss, str(crop_size))
        print(str_msg)

        out = y_pred
        draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/{}_out.png'.format(i), out)

    K.clear_session()
Beispiel #3
0
        y_pred = final.predict(x_test)
        print('predict has finished')

        y_pred = np.reshape(y_pred, (hight, wight))
        print(y_pred.shape)
        y_pred = y_pred * 255.0
        y_pred = get_final_output(y_pred, pd_trimap)
        y_pred = y_pred.astype(np.uint8)

        out_1 = y_pred.copy()
        all_out = np.zeros((bg_h, bg_w), np.float32)
        all_out[0:bg_h, 0:bg_w] = out_1[0:bg_h, 0:bg_w]

        all_out_for_loss = all_out.astype(np.uint8)

        sad_loss = compute_sad_loss(all_out_for_loss, alpha, trimap)
        total_sad_loss += sad_loss
        mse_loss = compute_mse_loss(all_out_for_loss, alpha, trimap)
        total_mse_loss += mse_loss
        str_msg = 'sad_loss: %.4f, mse_loss: %.4f' % (sad_loss, mse_loss)
        print(str_msg)

        # draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/adobe_data/{}_out.png'.format(i), all_out)
        K.clear_session()
    mean_sad = total_sad_loss / (1.0 * test_data_number)
    mean_mse = total_mse_loss / (1.0 * test_data_number)
    print(mean_sad)
    with open('images/test_result.txt', 'w') as f1:
        f1.write(str(mean_sad) + " " + str(mean_mse))
Beispiel #4
0
def video_test(videopath, models_path, videomask_path, is_rotate=False):
    """Run the matting network over a video and write the result videos.

    Each frame of ``videopath`` is paired with the corresponding frame of
    ``videomask_path``. The mask frame is thresholded into a binary alpha,
    turned into a trimap with ``generate_trimap_withmask`` and fed, together
    with the frame, through the refinement model at a fixed 320x320
    resolution. The prediction is resized back to the frame size and three
    videos are written next to ``videopath``: ``*_tri.mp4`` (trimap),
    ``*_out.mp4`` (predicted matte) and ``*_comp.mp4`` (composite).

    Preview windows are shown while processing; pressing ``q`` stops early.

    Args:
        videopath: Path of the input video.
        models_path: Path of the weights file loaded into the model.
        videomask_path: Path of the video holding the per-frame masks.
        is_rotate: When true, both frames are rotated by 90 degrees with
            ``imutils.rotate_bound`` before processing.
    """
    # Local import so this function does not depend on the file-level imports.
    import os

    # Resolution the frame and trimap are resized to before prediction.
    net_input_size = (320, 320)
    # Only used for the log/overlay message; inference itself always runs at
    # net_input_size.
    different_sizes = [(320, 320), (320, 320), (320, 320), (480, 480),
                       (640, 640)]

    cam = cv.VideoCapture(videopath)
    cam_mask = cv.VideoCapture(videomask_path)
    writers = []
    try:
        width = int(cam.get(cv.CAP_PROP_FRAME_WIDTH))
        height = int(cam.get(cv.CAP_PROP_FRAME_HEIGHT))

        encoder_decoder = build_encoder_decoder()
        final = build_refinement(encoder_decoder)
        final.load_weights(models_path)
        # summary() prints by itself; wrapping it in print() only added a
        # stray "None" line.
        final.summary()

        # Everything written below has the size of the (possibly rotated)
        # input frame, so the writers are sized accordingly. A 90 degree
        # rotate_bound swaps width and height.
        # NOTE(review): assumes video_save forwards w/h as the writer's frame
        # size - confirm against its definition.
        if is_rotate:
            out_w, out_h = height, width
        else:
            out_w, out_h = width, height

        # splitext instead of [:-4] so extensions of any length are handled.
        base_path = os.path.splitext(videopath)[0]
        tri_video = video_save(base_path + '_tri.mp4', w=out_w, h=out_h)
        writers.append(tri_video)
        matting_video = video_save(base_path + '_out.mp4', w=out_w, h=out_h)
        writers.append(matting_video)
        comp_video = video_save(base_path + '_comp.mp4', w=out_w, h=out_h)
        writers.append(comp_video)

        while cam.isOpened() and cam_mask.isOpened():
            start_time = time.time()
            frame_ok, frame = cam.read()
            mask_ok, frame_mask = cam_mask.read()

            # Validate before rotating: at the end of a stream read() yields
            # no frame, and rotating None would raise.
            if not frame_ok or frame is None:
                print('Error image!')
                break
            if not mask_ok or frame_mask is None:
                print('Error mask image!')
                break

            if is_rotate:
                frame = imutils.rotate_bound(frame, 90)
                frame_mask = imutils.rotate_bound(frame_mask, 90)

            # Use the real array shapes so rotation is accounted for.
            bg_h, bg_w = frame.shape[:2]
            print('bg_h, bg_w: ' + str((bg_h, bg_w)))

            a = cv.cvtColor(frame_mask, cv.COLOR_BGR2GRAY)
            _, a = cv.threshold(a, 240, 255, cv.THRESH_BINARY)

            a_h, a_w = a.shape[:2]
            print('a_h, a_w: ' + str((a_h, a_w)))

            # Paste the mask into the top-left corner of a frame-sized alpha,
            # clipping to the overlap so differing sizes cannot raise.
            alpha = np.zeros((bg_h, bg_w), np.float32)
            copy_h = min(a_h, bg_h)
            copy_w = min(a_w, bg_w)
            alpha[0:copy_h, 0:copy_w] = a[0:copy_h, 0:copy_w]
            trimap = generate_trimap_withmask(alpha)

            crop_size = random.choice(different_sizes)

            # Network input: BGR channels plus the trimap, scaled by 1/255.
            x_test = np.empty((1,) + net_input_size + (4,), dtype=np.float32)
            x_test[0, :, :, 0:3] = cv.resize(frame, net_input_size) / 255.
            x_test[0, :, :, 3] = cv.resize(trimap, net_input_size) / 255.

            y_pred = final.predict(x_test)
            y_pred = np.reshape(y_pred, net_input_size)
            print(y_pred.shape)
            y_pred = cv.resize(y_pred, (bg_w, bg_h))
            # Preview before scaling: imshow maps float images from [0, 1]
            # to [0, 255], so the 0-255 float version shows up saturated.
            # NOTE(review): presumes the model output lies in [0, 1] - confirm.
            cv.imshow('pred', y_pred)
            y_pred = y_pred * 255.0
            y_pred = get_final_output(y_pred, trimap)
            y_pred = y_pred.astype(np.uint8)

            sad_loss = compute_sad_loss(y_pred, alpha, trimap)
            mse_loss = compute_mse_loss(y_pred, alpha, trimap)
            str_msg = 'sad_loss: %.4f, mse_loss: %.4f, crop_size: %s' % (
                sad_loss, mse_loss, str(crop_size))
            print(str_msg)

            out = y_pred.copy()
            # Composite from the clean matte, before the text overlay is drawn.
            comp = composite_alpha(frame, out)
            draw_str(out, (10, 20), str_msg)

            # The writers receive 3-channel frames, so replicate the
            # single-channel trimap and matte across three channels.
            trimap_show = np.stack((trimap, trimap, trimap), -1)
            out_show = cv.merge((out, out, out))
            tri_video.write(trimap_show)
            matting_video.write(out_show)
            comp_video.write(comp)

            print("Time: {:.2f} s / img".format(time.time() - start_time))

            cv.imshow('out', out)
            cv.imshow('frame', frame)
            cv.imshow('comp', comp)
            cv.imshow('trimap', trimap)

            if cv.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always release captures, writers and windows, even on an error.
        cam.release()
        cam_mask.release()
        for writer in writers:
            writer.release()
        cv.destroyAllWindows()