Пример #1
0
    def inference(self, x, x_mask, y, y_mask, training):
        """Predict rotamer outputs and, when training, their masked MSE loss.

        The inputs are cleaned with ``x_mask`` and fed to two parallel
        branches (``self.transformer`` with a padding mask built from
        ``x_mask``, and ``self.cnn``). The cleaned inputs and both branch
        outputs are concatenated on the last axis, cleaned again and passed
        to ``self.birnn``.

        Args:
            x: input features.
            x_mask: mask used to build the padding mask and to clean inputs.
            y: targets; columns ``15:23`` of the last axis are compared with
                the predictions.
            y_mask: mask aligned with ``y``.
            training: flag forwarded to every sub-network; the loss is only
                computed when it compares equal to ``True``.

        Returns:
            ``(rota_predictions, loss)``; ``loss`` is ``None`` unless
            ``training == True``.
        """
        feature_width = self.params["d_input"]
        padding_mask = create_padding_mask(x_mask)
        cleaned = clean_inputs(x, x_mask, feature_width)

        attention_features = self.transformer(cleaned,
                                              padding_mask,
                                              training=training)
        conv_features = self.cnn(cleaned, training=training)
        merged = tf.concat((cleaned, conv_features, attention_features), -1)
        # NOTE(review): 3 * d_input presumably assumes each branch emits
        # d_input features — confirm against the transformer/cnn configs.
        merged = clean_inputs(merged, x_mask, 3 * feature_width)

        predictions = self.birnn(merged, x_mask, training=training)

        # Explicit `== True` (not truthiness) is kept deliberately so that
        # non-bool training flags behave exactly as before.
        if training == True:
            loss = compute_mse_loss(predictions, y[:, :, 15:23],
                                    y_mask[:, :, 15:23])
        else:
            loss = None
        return predictions, loss
Пример #2
0
            # testing
            feed_data_ = np.concatenate([data, tri_map, gradient, roughness],
                                        axis=2)
            feed_data_ = np.expand_dims(np.transpose(feed_data_, (2, 0, 1)),
                                        axis=0)
            shape_model.feed_input_with_shape(feed_data_)
            duration, pred = shape_model.predict_with_shape_data()
            log.info("Processed %s, consumed %f second." %
                     (item_name, duration))
            pred = np.where(np.equal(tri_map[:, :, 0], unknown_code), pred,
                            tri_map[:, :, 0])
            pred = cv2.resize(
                pred, (original_shape[1], original_shape[0]), \
                interpolation=cv2.INTER_CUBIC
            )
            mse = compute_mse_loss(pred, gt, tri_map_original)
            shape_mse += mse
            log.info("mse for %s is: %f" % (item_name, mse))
            output_img = Image.fromarray(pred).convert("L")
            output_dir = os.path.join("../", "shape-test-output")
            if not os.path.exists(output_dir): os.mkdir(output_dir)
            output_filename = os.path.join(output_dir, item_name)
            output_img.save(output_filename)
            time_ += duration

        shape_mse /= nums
        log.info("Mean time consumption every single image is: %f" %
                 (time_ / nums))
        log.info("Mean mse is: %f" % shape_mse)
Пример #3
0
        y_true = np.empty((1, img_rows, img_cols, 2), dtype=np.float32)
        y_true[0, :, :, 0] = alpha / 255.
        y_true[0, :, :, 1] = trimap / 255.

        y_pred = final.predict(x_test)
        # print('y_pred.shape: ' + str(y_pred.shape))

        y_pred = np.reshape(y_pred, (img_rows, img_cols))
        print(y_pred.shape)
        y_pred = y_pred * 255.0
        y_pred = get_final_output(y_pred, trimap)
        y_pred = y_pred.astype(np.uint8)

        sad_loss = compute_sad_loss(y_pred, alpha, trimap)
        mse_loss = compute_mse_loss(y_pred, alpha, trimap)
        str_msg = 'sad_loss: %.4f, mse_loss: %.4f, crop_size: %s' % (sad_loss, mse_loss, str(crop_size))
        print(str_msg)

        out = y_pred.copy()
        draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/{}_out.png'.format(i), out)

        sample_bg = sample_bgs[i]
        bg = cv.imread(os.path.join(bg_test, sample_bg))
        bh, bw = bg.shape[:2]
        wratio = img_cols / bw
        hratio = img_rows / bh
        ratio = wratio if wratio > hratio else hratio
        if ratio > 1:
            bg = cv.resize(src=bg, dsize=(math.ceil(bw * ratio), math.ceil(bh * ratio)), interpolation=cv.INTER_CUBIC)
Пример #4
0
        y_pred = final.predict(x_test)
        print('predict has finished')

        y_pred = np.reshape(y_pred, (hight, wight))
        print(y_pred.shape)
        y_pred = y_pred * 255.0
        y_pred = get_final_output(y_pred, pd_trimap)
        y_pred = y_pred.astype(np.uint8)

        out_1 = y_pred.copy()
        all_out = np.zeros((bg_h, bg_w), np.float32)
        all_out[0:bg_h, 0:bg_w] = out_1[0:bg_h, 0:bg_w]

        all_out_for_loss = all_out.astype(np.uint8)

        sad_loss = compute_sad_loss(all_out_for_loss, alpha, trimap)
        total_sad_loss += sad_loss
        mse_loss = compute_mse_loss(all_out_for_loss, alpha, trimap)
        total_mse_loss += mse_loss
        str_msg = 'sad_loss: %.4f, mse_loss: %.4f' % (sad_loss, mse_loss)
        print(str_msg)

        # draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/adobe_data/{}_out.png'.format(i), all_out)
        K.clear_session()
    mean_sad = total_sad_loss / (1.0 * test_data_number)
    mean_mse = total_mse_loss / (1.0 * test_data_number)
    print(mean_sad)
    with open('images/test_result.txt', 'w') as f1:
        f1.write(str(mean_sad) + " " + str(mean_mse))
Пример #5
0
        cv.imwrite('images/{}_trimap.png'.format(i), np.array(trimap).astype(np.uint8))
        cv.imwrite('images/{}_alpha.png'.format(i), np.array(alpha).astype(np.uint8))

        x_test = np.empty((1, 320, 320, 4), dtype=np.float32)
        x_test[0, :, :, 0:3] = bgr_img / 255.
        x_test[0, :, :, 3] = trimap / 255.

        y_true = np.empty((1, 320, 320, 2), dtype=np.float32)
        y_true[0, :, :, 0] = alpha / 255.
        y_true[0, :, :, 1] = trimap / 255.

        y_pred = final.predict(x_test)
        # print('y_pred.shape: ' + str(y_pred.shape))

        y_pred = np.reshape(y_pred, (img_rows, img_cols))
        print(y_pred.shape)
        y_pred = y_pred * 255.0
        y_pred = get_final_output(y_pred, trimap)
        y_pred = y_pred.astype(np.uint8)

        sad_loss = compute_sad_loss(y_pred, alpha, trimap)
        mse_loss = compute_mse_loss(y_pred, alpha, trimap)
        str_msg = 'sad_loss: %.4f, mse_loss: %.4f, crop_size: %s' % (sad_loss, mse_loss, str(crop_size))
        print(str_msg)

        out = y_pred
        draw_str(out, (10, 20), str_msg)
        cv.imwrite('images/{}_out.png'.format(i), out)

    K.clear_session()
Пример #6
0
    args = parser.parse_args()

    mse_loss = []
    sad_loss = []

    ### loss_unknown only consider the unknown regions, i.e. trimap==128, as trimap-based methods do
    mse_loss_unknown = []
    sad_loss_unknown = []

    for img in os.listdir(args.label_dir):
        print(img)
        pred = cv2.imread(os.path.join(args.pred_dir, img), 0).astype(np.float32)
        label = cv2.imread(os.path.join(args.label_dir, img), 0).astype(np.float32)
        trimap = cv2.imread(os.path.join(args.trimap_dir, img), 0).astype(np.float32)

        mse_loss_unknown_ = compute_mse_loss(pred, label, trimap)
        sad_loss_unknown_ = comput_sad_loss(pred, label, trimap)[0]

        trimap[...] = 128

        mse_loss_ = compute_mse_loss(pred, label, trimap)
        sad_loss_ = comput_sad_loss(pred, label, trimap)[0]

        print('Whole Image: MSE:', mse_loss_, ' SAD:', sad_loss_)
        print('Unknown Region: MSE:', mse_loss_unknown_, ' SAD:', sad_loss_unknown_)

        mse_loss_unknown.append(mse_loss_unknown_)
        sad_loss_unknown.append(sad_loss_unknown_)

        mse_loss.append(mse_loss_)
        sad_loss.append(sad_loss_)
Пример #7
0
    def inference(self, x, x_mask, y, y_mask, training):
        """Run the shared encoder and the task heads selected by ``self.name``.

        The first two characters of ``self.name`` choose which predictions
        ``self.birnn`` returns (in this order) and which loss terms are summed:

        * ``"ss"``: ss8, ss3
        * ``"pp"``: phi/psi only (unweighted MSE)
        * ``"c2"``: ss8, ss3, phi/psi (x4)
        * ``"c3"``: as ``"c2"`` plus csf (x0.1)
        * ``"c4"``: as ``"c3"`` plus asa (x3)
        * ``"c5"``: as ``"c4"`` plus rota

        Target columns on the last axis of ``y`` / ``y_mask``: ss8 ``:8``,
        csf ``8:11``, phi/psi ``11:15``, rota ``15:23``, asa ``23``,
        ss3 ``30:33``. ss8/ss3 use cross-entropy; the rest use MSE.

        Args:
            x: input features.
            x_mask: mask used to build the padding mask and to clean inputs.
            y: stacked targets (see column layout above).
            y_mask: mask aligned with ``y``.
            training: flag forwarded to every sub-network; the loss is only
                computed when it compares equal to ``True``.

        Returns:
            Tuple of the task's prediction tensors in ``self.birnn`` output
            order, followed by the loss (``None`` unless ``training == True``).

        Raises:
            ValueError: if ``self.name[:2]`` is not one of the prefixes above,
                or ``self.birnn`` returns a different number of heads than the
                task expects.
        """
        task = self.name[:2]

        def ss8_term(pred):
            return compute_cross_entropy_loss(pred, y[:, :, :8], y_mask[:, :, :8])

        def ss3_term(pred):
            return compute_cross_entropy_loss(pred, y[:, :, 30:33],
                                              y_mask[:, :, 30:33])

        def phipsi_mse(pred):
            return compute_mse_loss(pred, y[:, :, 11:15], y_mask[:, :, 11:15])

        def phipsi_term(pred):
            return 4 * phipsi_mse(pred)

        def csf_term(pred):
            return 0.1 * compute_mse_loss(pred, y[:, :, 8:11], y_mask[:, :, 8:11])

        def asa_term(pred):
            return 3 * compute_mse_loss(pred, tf.expand_dims(y[:, :, 23], -1),
                                        tf.expand_dims(y_mask[:, :, 23], -1))

        def rota_term(pred):
            return compute_mse_loss(pred, y[:, :, 15:23], y_mask[:, :, 15:23])

        # The joint variants c2..c5 each add one head, in birnn output order.
        joint_terms = (ss8_term, ss3_term, phipsi_term, csf_term, asa_term,
                       rota_term)
        loss_terms = {
            "ss": joint_terms[:2],
            "pp": (phipsi_mse,),
            "c2": joint_terms[:3],
            "c3": joint_terms[:4],
            "c4": joint_terms[:5],
            "c5": joint_terms[:6],
        }.get(task)
        if loss_terms is None:
            # Previously an unknown prefix silently returned None after a
            # full forward pass.
            raise ValueError("Unknown task prefix {!r} in model name {!r}".format(
                task, self.name))

        encoder_padding_mask = create_padding_mask(x_mask)

        x = clean_inputs(x, x_mask, self.params["d_input"])

        transformer_out = self.transformer(x,
                                           encoder_padding_mask,
                                           training=training)
        cnn_out = self.cnn(x, training=training)
        x = tf.concat((x, cnn_out, transformer_out), -1)

        x = clean_inputs(x, x_mask, 3 * self.params["d_input"])

        outputs = self.birnn(x, x_mask, training=training)
        # The single-head "pp" network returns one tensor rather than a tuple.
        predictions = (outputs,) if task == "pp" else tuple(outputs)
        if len(predictions) != len(loss_terms):
            raise ValueError("Task {!r} expects {} birnn outputs, got {}".format(
                task, len(loss_terms), len(predictions)))

        loss = None
        # `== True` kept so non-bool training flags behave as before.
        if training == True:
            # Accumulate left to right in head order, matching the original
            # `a + b + 4*c + ...` expressions term for term.
            for term, pred in zip(loss_terms, predictions):
                value = term(pred)
                loss = value if loss is None else loss + value
        return predictions + (loss,)
Пример #8
0
    def test(self):
        """Evaluate ``self.G`` on ``self.test_dataloader`` and log matting metrics.

        For each batch the generator predicts an alpha matte. When
        ``CONFIG.test.TTA`` is set, the prediction is averaged over 4 rotations
        x {original, horizontally flipped}. Predictions are forced to 1 where
        the trimap is 2 and to 0 where it is 0. Each sample is cropped to
        ``alpha_shape``, optionally written as PNG to
        ``test_config.alpha_path``, and scored with MSE and SAD, plus the
        gradient error unless ``test_config.fast_eval`` is set. The mean
        metrics and the accumulated forward time are logged.
        """
        self.G = self.G.eval()
        use_cuda = not self.test_config.cpu
        mse_loss = 0
        sad_loss = 0
        grad_loss = 0

        test_num = 0
        all_time = 0
        with torch.no_grad():
            for image_dict in self.test_dataloader:
                image, alpha, trimap = (image_dict['image'], image_dict['alpha'],
                                        image_dict['trimap'])
                alpha_shape, name = image_dict['alpha_shape'], image_dict[
                    'image_name']
                if use_cuda:
                    image = image.cuda()
                    alpha = alpha.cuda()
                    trimap = trimap.cuda()
                if self.test_config.fp16:
                    image = image.half()
                    alpha = alpha.half()
                    trimap = trimap.half()
                t1 = time.time()
                if not CONFIG.test.TTA:
                    alpha_pred, _ = self.G(image, trimap)
                else:
                    # Average 8 predictions: each rotated (and optionally
                    # flipped) input's output is mapped back to the original
                    # orientation before accumulating.
                    alpha_pred = torch.zeros_like(image[:, :1, :, :])
                    for rot in range(4):
                        alpha_tmp, _ = self.G(image.rot90(rot, [2, 3]),
                                              trimap.rot90(rot, [2, 3]))
                        alpha_pred += alpha_tmp.rot90(rot, [3, 2])
                        alpha_tmp, _ = self.G(
                            image.flip(3).rot90(rot, [2, 3]),
                            trimap.flip(3).rot90(rot, [2, 3]))
                        alpha_pred += alpha_tmp.rot90(rot, [3, 2]).flip(3)

                    alpha_pred = alpha_pred / 8
                if use_cuda:
                    # Wait for queued GPU work so the timing is meaningful.
                    # Skipped in CPU mode: synchronize() fails on machines
                    # without CUDA, which crashed CPU-only evaluation.
                    torch.cuda.synchronize()
                all_time += time.time() - t1

                if self.model_config.trimap_channel == 3:
                    trimap = trimap.argmax(dim=1, keepdim=True)

                # Trust the trimap where it is known: 2 -> alpha 1, 0 -> alpha 0.
                alpha_pred[trimap == 2] = 1
                alpha_pred[trimap == 0] = 0

                # Re-encode the trimap as 0/128/255 (presumably the convention
                # the metric helpers expect — confirm).
                trimap[trimap == 2] = 255
                trimap[trimap == 1] = 128

                for cnt in range(image.shape[0]):

                    # NOTE(review): alpha_shape is unpacked for the whole batch,
                    # which presumably assumes batch size 1 — confirm.
                    h, w = alpha_shape
                    test_alpha = alpha[cnt, 0, ...].data.cpu().numpy() * 255
                    test_pred = alpha_pred[cnt, 0,
                                           ...].data.cpu().numpy() * 255
                    test_pred = test_pred.astype(np.uint8)
                    test_trimap = trimap[cnt, 0, ...].data.cpu().numpy()

                    # NOTE(review): test_alpha is not cropped; presumably the
                    # loader returns alpha unpadded — confirm.
                    test_pred = test_pred[:h, :w]
                    test_trimap = test_trimap[:h, :w]
                    if self.test_config.alpha_path is not None:
                        cv2.imwrite(
                            os.path.join(
                                self.test_config.alpha_path,
                                os.path.splitext(name[cnt])[0] + ".png"),
                            test_pred)

                    mse_loss += compute_mse_loss(test_pred, test_alpha,
                                                 test_trimap)
                    # Compute SAD once and log it for this sample only
                    # (previously computed twice and logged the whole batch
                    # name list).
                    sad = comput_sad_loss(test_pred, test_alpha, test_trimap)[0]
                    self.logger.info("{} {}".format(name[cnt], sad))
                    sad_loss += sad
                    if not self.test_config.fast_eval:
                        # Connectivity error is intentionally not computed here.
                        grad_loss += compute_gradient_loss(
                            test_pred, test_alpha, test_trimap)

                    test_num += 1

        self.logger.info("TEST NUM: \t\t {}".format(test_num))
        # Guard against an empty loader, which previously raised
        # ZeroDivisionError.
        if test_num > 0:
            self.logger.info("MSE: \t\t {}".format(mse_loss / test_num))
            self.logger.info("SAD: \t\t {}".format(sad_loss / test_num))
            if not self.test_config.fast_eval:
                self.logger.info("GRAD: \t\t {}".format(grad_loss / test_num))

        self.logger.info("TIME: {}".format(all_time))
Пример #9
0
    def test(self):
        """Evaluate ``self.G`` on ``self.test_dataloader`` and log matting metrics.

        For each batch the generator predicts an alpha matte. Predictions are
        forced to 1 where the trimap is 2 and to 0 where it is 0. Each sample
        is cropped to ``alpha_shape``, optionally written as PNG to
        ``test_config.alpha_path``, and scored with MSE and SAD. Gradient and
        connectivity errors are added unless ``test_config.fast_eval`` is set.
        Mean metrics are logged.
        """
        self.G = self.G.eval()
        mse_loss = 0
        sad_loss = 0
        conn_loss = 0
        grad_loss = 0

        test_num = 0

        with torch.no_grad():
            for image_dict in self.test_dataloader:
                image, alpha, trimap = (image_dict['image'], image_dict['alpha'],
                                        image_dict['trimap'])
                alpha_shape, name = image_dict['alpha_shape'], image_dict[
                    'image_name']
                if not self.test_config.cpu:
                    image = image.cuda()
                    alpha = alpha.cuda()
                    trimap = trimap.cuda()
                alpha_pred, _ = self.G(image, trimap)

                if self.model_config.trimap_channel == 3:
                    trimap = trimap.argmax(dim=1, keepdim=True)

                # Trust the trimap where it is known: 2 -> alpha 1, 0 -> alpha 0.
                alpha_pred[trimap == 2] = 1
                alpha_pred[trimap == 0] = 0

                # Re-encode the trimap as 0/128/255 (presumably the convention
                # the metric helpers expect — confirm).
                trimap[trimap == 2] = 255
                trimap[trimap == 1] = 128

                for cnt in range(image.shape[0]):

                    # NOTE(review): alpha_shape is unpacked for the whole batch,
                    # which presumably assumes batch size 1 — confirm.
                    h, w = alpha_shape
                    test_alpha = alpha[cnt, 0, ...].data.cpu().numpy() * 255
                    test_pred = alpha_pred[cnt, 0,
                                           ...].data.cpu().numpy() * 255
                    test_pred = test_pred.astype(np.uint8)
                    test_trimap = trimap[cnt, 0, ...].data.cpu().numpy()

                    test_pred = test_pred[:h, :w]
                    test_trimap = test_trimap[:h, :w]

                    if self.test_config.alpha_path is not None:
                        cv2.imwrite(
                            os.path.join(
                                self.test_config.alpha_path,
                                os.path.splitext(name[cnt])[0] + ".png"),
                            test_pred)

                    mse_loss += compute_mse_loss(test_pred, test_alpha,
                                                 test_trimap)
                    # Compute SAD once. Log it through the same logger as the
                    # summary, for this sample's name only (previously printed
                    # the whole batch name list).
                    sad = comput_sad_loss(test_pred, test_alpha, test_trimap)[0]
                    self.logger.info("{} {}".format(name[cnt], sad))
                    sad_loss += sad
                    if not self.test_config.fast_eval:
                        conn_loss += compute_connectivity_error(
                            test_pred, test_alpha, test_trimap, 0.1)
                        grad_loss += compute_gradient_loss(
                            test_pred, test_alpha, test_trimap)

                    test_num += 1

        self.logger.info("TEST NUM: \t\t {}".format(test_num))
        if test_num == 0:
            # Nothing evaluated; averaging would raise ZeroDivisionError.
            return
        self.logger.info("MSE: \t\t {}".format(mse_loss / test_num))
        self.logger.info("SAD: \t\t {}".format(sad_loss / test_num))
        if not self.test_config.fast_eval:
            self.logger.info("GRAD: \t\t {}".format(grad_loss / test_num))
            self.logger.info("CONN: \t\t {}".format(conn_loss / test_num))
Пример #10
0
    for img in os.listdir(args.label_dir):
        print(img)
        pred = cv2.imread(
            os.path.join(args.pred_dir, img.replace('.png', '.jpg')),
            0).astype(np.float32)
        label = cv2.imread(os.path.join(args.label_dir, img),
                           0).astype(np.float32)
        detailmap = cv2.imread(
            os.path.join(args.detailmap_dir,
                         img.replace('.png', '_cloud_trimap.jpg')),
            0).astype(np.float32)

        detailmap[detailmap > 0] = 128

        mse_loss_unknown_ = compute_mse_loss(pred, label, detailmap)
        sad_loss_unknown_ = comput_sad_loss(pred, label, detailmap)[0]

        detailmap[...] = 128

        mse_loss_ = compute_mse_loss(pred, label, detailmap)
        sad_loss_ = comput_sad_loss(pred, label, detailmap)[0]

        print('Whole Image: MSE:', mse_loss_, ' SAD:', sad_loss_)
        print('Detail Region: MSE:', mse_loss_unknown_, ' SAD:',
              sad_loss_unknown_)

        mse_loss_unknown.append(mse_loss_unknown_)
        sad_loss_unknown.append(sad_loss_unknown_)

        mse_loss.append(mse_loss_)
Пример #11
0
def video_test(videopath, models_path, videomask_path, is_rotate=False):
    """Run the matting network frame by frame over a video and its mask video.

    For each frame pair:

    1. The mask frame is converted to grayscale and thresholded at 240 into a
       binary alpha.
    2. A trimap is generated from that alpha.
    3. The frame and the trimap are resized to 320x320 and fed to the
       refinement network.
    4. The prediction is resized back to the frame size, post-processed with
       ``get_final_output``, and scored (SAD/MSE) against the mask alpha.

    Three videos are written next to ``videopath``: ``<stem>_tri.mp4``
    (trimap), ``<stem>_out.mp4`` (predicted alpha) and ``<stem>_comp.mp4``
    (composite). Preview windows are shown; press ``q`` to stop early.

    Args:
        videopath: input video path; its last four characters are replaced
            to build the output paths (assumes a 3-letter extension).
        models_path: weights file loaded into the refinement model.
        videomask_path: mask video read in lockstep with ``videopath``.
        is_rotate: if true, both frames are rotated by 90 degrees with
            ``imutils.rotate_bound`` before processing.
    """
    cam = cv.VideoCapture(videopath)
    width = int(cam.get(cv.CAP_PROP_FRAME_WIDTH))
    height = int(cam.get(cv.CAP_PROP_FRAME_HEIGHT))

    cam_mask = cv.VideoCapture(videomask_path)

    encoder_decoder = build_encoder_decoder()
    final = build_refinement(encoder_decoder)
    final.load_weights(models_path)
    # summary() prints itself and returns None; printing its result printed
    # "None".
    final.summary()

    # Every written frame has the (possibly rotated) input-frame size, not the
    # mask size; a 90-degree rotate_bound swaps width and height.
    out_w, out_h = (height, width) if is_rotate else (width, height)
    stem = videopath[:-4]
    tri_video = video_save(stem + '_tri.mp4', w=out_w, h=out_h)
    matting_video = video_save(stem + '_out.mp4', w=out_w, h=out_h)
    comp_video = video_save(stem + '_comp.mp4', w=out_w, h=out_h)

    try:
        while cam.isOpened() and cam_mask.isOpened():
            start_time = time.time()
            ok, frame = cam.read()
            ok_mask, frame_mask = cam_mask.read()

            # Check end-of-stream *before* rotating: rotate_bound(None) raises.
            if not ok or frame is None:
                print('Error image!')
                break
            if not ok_mask or frame_mask is None:
                print('Error mask image!')
                break
            if is_rotate:
                frame = imutils.rotate_bound(frame, 90)
                frame_mask = imutils.rotate_bound(frame_mask, 90)

            # Use the actual (post-rotation) frame size, not the capture props.
            bg_h, bg_w = frame.shape[:2]
            print('bg_h, bg_w: ' + str((bg_h, bg_w)))

            a = cv.cvtColor(frame_mask, cv.COLOR_BGR2GRAY)
            _, a = cv.threshold(a, 240, 255, cv.THRESH_BINARY)

            a_h, a_w = a.shape[:2]
            print('a_h, a_w: ' + str((a_h, a_w)))

            # Paste the mask into a frame-sized alpha, cropping any overhang.
            alpha = np.zeros((bg_h, bg_w), np.float32)
            paste_h, paste_w = min(a_h, bg_h), min(a_w, bg_w)
            alpha[:paste_h, :paste_w] = a[:paste_h, :paste_w]
            trimap = generate_trimap_withmask(alpha)

            x_test = np.empty((1, 320, 320, 4), dtype=np.float32)
            x_test[0, :, :, 0:3] = cv.resize(frame, (320, 320)) / 255.
            x_test[0, :, :, 3] = cv.resize(trimap, (320, 320)) / 255.

            y_pred = final.predict(x_test)
            y_pred = np.reshape(y_pred, (320, 320))
            print(y_pred.shape)
            y_pred = cv.resize(y_pred, (bg_w, bg_h))
            # imshow maps float images from [0, 1]; show before scaling to
            # 0-255, otherwise the preview saturates to white.
            cv.imshow('pred', y_pred)
            y_pred = y_pred * 255.0
            y_pred = get_final_output(y_pred, trimap)
            y_pred = y_pred.astype(np.uint8)

            sad_loss = compute_sad_loss(y_pred, alpha, trimap)
            mse_loss = compute_mse_loss(y_pred, alpha, trimap)
            # The former randomly chosen "crop_size" was never used for
            # processing, so it is no longer reported.
            str_msg = 'sad_loss: %.4f, mse_loss: %.4f' % (sad_loss, mse_loss)
            print(str_msg)

            out = y_pred
            # Composite from the clean prediction before the text overlay.
            comp = composite_alpha(frame, out)
            draw_str(out, (10, 20), str_msg)

            tri_video.write(np.stack((trimap, trimap, trimap), -1))
            matting_video.write(cv.merge((out, out, out)))
            comp_video.write(comp)

            print("Time: {:.2f} s / img".format(time.time() - start_time))

            cv.imshow('out', out)
            cv.imshow('frame', frame)
            cv.imshow('comp', comp)
            cv.imshow('trimap', trimap)

            if cv.waitKey(1) & 0xFF == ord('q'):
                break
    finally:
        # Always finalize the output files and free the captures, even if a
        # frame fails mid-way.
        cam.release()
        cam_mask.release()
        tri_video.release()
        matting_video.release()
        comp_video.release()
        cv.destroyAllWindows()