示例#1
0
def test_uv():
    """Load four hard-coded EXR UV predictions and dump each one as a JPEG.

    For every file the name and the loaded array's shape are printed, then
    the array is passed to ``print_img_auto`` with image type ``"uv"``. The
    output path is ``tfilename("uvtest", <name minus its last 4 chars> +
    ".jpg")``.

    Raises:
        OSError: if ``cv2.imread`` fails to load a file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, which
            previously surfaced as an opaque ``AttributeError`` on
            ``uv.shape``.
    """
    # NOTE: this is a filename *prefix*, not a directory; the image names are
    # appended to it directly.
    parent = "/home1/qiyuanwang/film_code/eval/deform_unwarp/pred_uv_ind_"
    image_names = (
        "CT500-CT412_12_6-4-2wXldE3caZ0005.exr",
        "CT500-CT420_1_4-4-svQ0rgJUEu0032.exr",
        "CT500-CT424_10_6-4-NDz7niawME0010.exr",
        "CT500-CT426_4_4-4-8pCH2dKNvn0031.exr",
    )

    for im in image_names:
        print(im)
        path = parent + im
        uv = cv2.imread(path, cv2.IMREAD_UNCHANGED)
        if uv is None:
            raise OSError("cv2.imread could not load {}".format(path))
        print(uv.shape)
        # im[:-4] strips the ".exr" suffix before the ".jpg" one is added.
        print_img_auto(uv, "uv", fname=tfilename("uvtest", im[:-4] + ".jpg"))
示例#2
0
def write_imgs_2(img_tuple,
                 epoch,
                 type_tuple=None,
                 name_tuple=None,
                 training=True):
    """Dump the images of one epoch as JPEG files.

    Every file is written to the path returned by
    ``tfilename(output_dir, "imgshow/epoch_<epoch>", <file name>)``.
    ``output_dir`` is not a parameter; it is read from the enclosing scope.

    Args:
        img_tuple: 14 images when ``training`` is true, in the order
            ``cmap, uv, ab, bg, nor, dep, bg2, ori_gt, cmap_gt, uv_gt,
            ab_gt, bg_gt, nor_gt, dep_gt``; otherwise only the first 8.
        epoch: value formatted into the ``imgshow/epoch_{}`` folder name.
        type_tuple: not used by this function.
        name_tuple: not used by this function.
        training: when true the ground-truth images and the dewarp results
            derived from them are written as well.
    """
    print("Writing Images to ", output_dir)
    if training:
        (cmap, uv, ab, bg, nor, dep, bg2, ori_gt,
         cmap_gt, uv_gt, ab_gt, bg_gt, nor_gt, dep_gt) = img_tuple
    else:
        cmap, uv, ab, bg, nor, dep, bg2, ori_gt = img_tuple

    def target(basename):
        # Output path of one image inside this epoch's folder.
        return tfilename(output_dir, "imgshow/epoch_{}".format(epoch),
                         basename)

    # --- network outputs: (image, image type, file name) ---
    raw_outputs = [
        (uv, "uv", "uv.jpg"),
        (ab, "ab", "ab.jpg"),
        (bg, "bg", "bg.jpg"),
        (bg2, "bg", "bg2.jpg"),
        (cmap, "exr", "cmap.jpg"),
        (nor, "exr", "nor.jpg"),
        (dep, "exr", "dep.jpg"),
        (ori_gt, "ori", "ori_gt.jpg"),
    ]
    if training:
        raw_outputs.extend([
            (uv_gt, "uv", "uv_gt.jpg"),
            (ab_gt, "ab", "ab_gt.jpg"),
            (bg_gt, "bg", "bg_gt.jpg"),
            (cmap_gt, "exr", "cmap_gt.jpg"),
            (nor_gt, "exr", "nor_gt.jpg"),
            (dep_gt, "exr", "dep_gt.jpg"),
        ])
    for image, kind, basename in raw_outputs:
        print_img_with_reprocess(image, kind, fname=target(basename))

    if training:
        # Same ground-truth maps again, passed through gt_clip first.
        clipped_outputs = (
            (cmap_gt, "cmap_gt2.jpg"),
            (nor_gt, "nor_gt2.jpg"),
            (dep_gt, "dep_gt2.jpg"),
        )
        for image, basename in clipped_outputs:
            print_img_with_reprocess(gt_clip(image),
                                     "exr",
                                     fname=target(basename))

    # --- backward map and dewarped image from the predictions ---
    uv_np = reprocess_auto(uv, "uv")
    bg2_np = reprocess_auto(bg2, "bg")
    ori_np = reprocess_auto(ori_gt, "ori")
    bw = uv2backward_trans_3(uv_np, bg2_np)
    dewarp = bw_mapping_single_3(ori_np, bw)
    derived_outputs = [
        (bw, "bw", "bw.jpg"),
        (dewarp, "ori", "dewarp.jpg"),
    ]

    if training:
        uv_gt_np = reprocess_auto(uv_gt, "uv")
        bg_gt_np = reprocess_auto(bg_gt, "bg")
        bw_gt = uv2backward_trans_3(uv_gt_np, bg_gt_np)
        # Predicted uv combined with the ground-truth background.
        bw2 = uv2backward_trans_3(uv_np, bg_gt_np)
        dewarp_gt = bw_mapping_single_3(ori_np, bw_gt)
        dewarp2 = bw_mapping_single_3(ori_np, bw2)
        derived_outputs.extend([
            (bw_gt, "bw", "bw_gt.jpg"),
            (bw2, "bw", "bw2.jpg"),
            (dewarp_gt, "ori", "dewarp_gt.jpg"),
            (dewarp2, "ori", "dewarp2.jpg"),
        ])

    for image, kind, basename in derived_outputs:
        print_img_auto(image, kind, fname=target(basename))
示例#3
0
def test_single(model, imgpath, writer):
    """Run ``model`` on one image file and dump the intermediate results.

    The image is read with ``cv2.imread``, resized to 256x256, converted
    with ``process_auto`` / ``process_to_tensor``, moved to CUDA and fed to
    the model. The predicted uv and mask maps are turned into a backward map
    that dewarps both the resized and the original-size image. All results
    are written as JPEGs to paths built by
    ``tfilename("output_test_single/", "real-img-" + generate_name(),
    <image file name> + <suffix>)``.

    Args:
        model: callable returning 12 outputs; the ones used here are
            positions 0 (uv), 3 (albedo), 5 (mask) and 11 (deform).
        imgpath: path of the image to process.
        writer: not used by this function.

    Raises:
        OSError: if ``cv2.imread`` cannot load ``imgpath`` (it returns
            ``None`` instead of raising).
    """
    test_name = generate_name()

    img_ori = cv2.imread(imgpath)
    if img_ori is None:
        raise OSError("cv2.imread could not load {}".format(imgpath))
    imgname = os.path.basename(imgpath)
    img = cv2.resize(img_ori, (256, 256))
    _input = process_auto(img, "ori")
    img_tensor = process_to_tensor(_input[np.newaxis, :, :, :])
    img_tensor = img_tensor.cuda()
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        uv_map, cmap, nor_map, alb_map, dep_map, mask_map, \
            _, _, _, _, _, deform_map = model(img_tensor)

    # First (only) element of the batch.
    alb_pred = alb_map[0, :, :, :]
    uv_pred = uv_map[0, :, :, :]
    mask_pred = torch.round(mask_map[0, :, :, :])
    # NOTE(review): the converted deform map is not written out anywhere;
    # the call is kept so the conversion is still exercised.
    deform2bw_tensor_batch(deform_map.detach().cpu())

    uv_np = reprocess_auto(uv_pred, "uv")
    mask_np = reprocess_auto(mask_pred, "background")
    alb_np = reprocess_auto(alb_pred, "ab")

    bw_np = uv2backward_trans_3(uv_np, mask_np)
    dewarp_np = bw_mapping_single_3(img, bw_np)

    # Original-resolution variants.
    bw_large = blur_bw_np(bw_np, img_ori)
    alb_large, diff, ori_gray, img_gray = resize_albedo_np2(
        img_ori, img, alb_np)

    dewarp_ab_large = bw_mapping_single_3(alb_large, bw_large)
    dewarp_large = bw_mapping_single_3(img_ori, bw_large)
    print("shape: ", dewarp_ab_large.shape)

    # (image, image type, file-name suffix)
    outputs = (
        (ori_gray, "ori", "_ori_gray.jpg"),
        (img_gray, "ori", "_img_gray.jpg"),
        (alb_np, "ab", "_ab_np.jpg"),
        (uv_np, "uv", "_uv_np.jpg"),
        (bw_np, "bw", "_bw_np.jpg"),
        (img_ori, "ori", "_imgori.jpg"),
        (img, "ori", "_ori.jpg"),
        (dewarp_np, "ori", "_dewarp_ori.jpg"),
        (dewarp_ab_large, "ab", "_large_ab.jpg"),
        (dewarp_large, "ori", "_large_ori.jpg"),
        (alb_large, "ab", "_al_large.jpg"),
        (diff, "ab", "_diff_large.jpg"),
    )
    for image, kind, suffix in outputs:
        print_img_auto(image,
                       kind,
                       fname=tfilename("output_test_single/",
                                       "real-img-" + test_name,
                                       imgname + suffix))
示例#4
0
    rgb = torch.from_numpy(rgb.transpose((0, 3, 1, 2))).cuda().float()
    return rgb, pad_img
    # import ipdb; ipdb.set_trace()


@tfuncname
def construct_plain_bg(bs, img_size=256, pad_size=0):
    """Build a batch of all-ones background maps as a CUDA float tensor.

    Args:
        bs: number of copies in the batch.
        img_size: height and width of the square map.
        pad_size: when greater than 0, a copy of the single (un-batched) map
            is also padded by this many pixels on every side with a constant
            zero border.

    Returns:
        ``(bg, pad_img)``: ``bg`` is the ``(bs, 1, img_size, img_size)``
        float tensor on CUDA; ``pad_img`` is whatever
        ``cv2.copyMakeBorder`` returns for the padded copy, or ``None`` when
        ``pad_size`` is not positive.
    """
    bg = np.ones((img_size, img_size, 1))
    pad_img = None
    if pad_size > 0:
        # Local import: elsewhere in this view cv2 is only imported inside
        # the __main__ guard, which does not run when the module is imported.
        import cv2
        BLACK = [0, 0, 0]  # WHITE = [1,1,1]
        # Pad the background itself. The previous code padded ``rgb``, a
        # name that is not defined in this function, so any pad_size > 0
        # failed.
        pad_img = cv2.copyMakeBorder(bg.copy(),
                                     pad_size,
                                     pad_size,
                                     pad_size,
                                     pad_size,
                                     cv2.BORDER_CONSTANT,
                                     value=BLACK)
    # (H, W, 1) -> (1, H, W, 1) -> (bs, H, W, 1) -> (bs, 1, H, W)
    bg = bg[np.newaxis, :, :, :]
    bg = np.repeat(bg, bs, axis=0)
    bg = torch.from_numpy(bg.transpose((0, 3, 1, 2))).cuda().float()
    return bg, pad_img


if __name__ == "__main__":
    # Manual check when the file is run as a script: build a batch of 3 with
    # construct_plain_cmap and write its first entry to a JPEG.
    import cv2
    from dataloader.print_img import print_img_auto
    # The second return value (the padded image) is not needed here.
    rgb, _ = construct_plain_cmap(3)
    print_img_auto(rgb[0, :, :, :], "cmap", fname="./plain_cmap_print4.jpg")
示例#5
0
def test(args, model, test_loader, optimizer, criterion,
         epoch, writer, output_dir, isWriteImage, isVal=False):
    """Evaluate the uv/bw model on at most 100 batches and log the results.

    For every batch the model is run on the 3D map, the uv prediction is
    masked with the ground-truth mask, and MSE / MAE / CC / PSNR are computed
    with ``uvbw_loss_np_batch``. Per-batch values go to ``<metric>_one/*``
    and running totals divided by the number of samples seen so far go to
    ``<metric>/*``. The first sample of each batch is also written as JPEGs
    under ``test_uvbw/batch_<idx>/`` together with backward-map difference
    images, and difference statistics are logged under ``bw_all_single/*``
    and ``bw_all_total/*``.

    SSIM is not computed (the call is disabled), so no ``ssim*`` scalars are
    written. They used to repeat the PSNR values and constant zeros.

    Args:
        args: not used by this function.
        model: called as ``model(threeD_map)``; must return
            ``(uv_pred, bw_pred)``.
        test_loader: iterable of batches indexed as ``data[0..4]`` =
            3D map, uv, bw, mask, original image.
        optimizer: not used by this function.
        criterion: not used by this function.
        epoch: not used by this function.
        writer: object exposing ``add_scalar(tag, value, global_step=...)``.
        output_dir: root directory passed to ``tfilename`` for image output.
        isWriteImage: not used by this function; images are always written.
        isVal: not used by this function.

    Returns:
        None.
    """
    model.eval()
    device = "cuda"
    count = 0
    metrics = ("mse", "mae", "cc", "psnr")
    # Running sums of the four losses returned for each metric.
    totals = {metric: [0, 0, 0, 0] for metric in metrics}
    # Running sums of |channel mean| of the backward-map differences.
    m1_1, m1_2, m2_1, m2_2 = 0, 0, 0, 0
    # NOTE(review): m2 is never accumulated; it is kept only so the existing
    # "bw_all_total/mean2" tag is still written.
    m2, s1, s2 = 0, 0, 0

    for batch_idx, data in enumerate(test_loader):
        if batch_idx >= 100:
            break
        threeD_map_gt = data[0].to(device)
        uv_map_gt = data[1].to(device)
        bw_map_gt = data[2].to(device)
        mask_map_gt = data[3].to(device)
        ori_map_gt = data[4]

        # Evaluation only: no autograd graph is needed.
        with torch.no_grad():
            uv_pred_t, bw_pred_t = model(threeD_map_gt)
            # Where the mask is not positive, take the mask value instead
            # of the uv prediction.
            uv_pred_t = torch.where(mask_map_gt > 0, uv_pred_t, mask_map_gt)

        uv_np = reprocess_auto_batch(uv_pred_t, "uv")
        uv_gt_np = reprocess_auto_batch(uv_map_gt, "uv")
        bw_np = reprocess_auto_batch(bw_pred_t, "bw")
        bw_gt_np = reprocess_auto_batch(bw_map_gt, "bw")
        mask_np = reprocess_auto_batch(mask_map_gt, "background")
        bw_uv = uv2backward_batch(uv_np, mask_np)
        ori = reprocess_auto_batch(ori_map_gt, "ori")

        count += uv_np.shape[0] * 1.0

        # ----------  per-batch metrics  ----------
        for metric in metrics:
            l1, l2, l3, l4 = uvbw_loss_np_batch(uv_np,
                                                bw_np,
                                                bw_gt_np,
                                                mask_np,
                                                ori,
                                                metrix=metric)
            losses = (l1, l2, l3, l4)
            totals[metric] = [
                total + loss for total, loss in zip(totals[metric], losses)
            ]
            _log_uvbw_scalars(writer, metric + "_one", losses, batch_idx)

        # NOTE(review): the label says "total_bw_loss" but the value printed
        # is the accumulated CC uv_bw loss, as before.
        print("Batch-idx {},  total_bw_loss {}".format(
            batch_idx, totals["cc"][0]))
        print("outputdir", output_dir)

        # ----------  running averages  ----------
        for metric in metrics:
            _log_uvbw_scalars(writer, metric,
                              [total / count for total in totals[metric]],
                              batch_idx)

        # ------------  Write Imgs --------------------
        def batch_file(basename):
            # Output path of one image inside this batch's folder.
            return tfilename(
                output_dir,
                "test_uvbw/batch_{}/{}".format(batch_idx, basename))

        dewarp_ori_bw = bw_mapping_batch_3(ori, bw_np, device="cuda")
        dewarp_ori_uv = bw_mapping_batch_3(ori, bw_uv, device="cuda")
        dewarp_ori_gt = bw_mapping_batch_3(ori, bw_gt_np, device="cuda")

        # (batched array, image type, file name); only sample 0 is written.
        first_sample_outputs = (
            (dewarp_ori_bw, "ori", "ori_bw.jpg"),
            (dewarp_ori_uv, "ori", "ori_uv.jpg"),
            (dewarp_ori_gt, "ori", "ori_dewarp_gt.jpg"),
            (uv_np, "uv", "uv_pred.jpg"),
            (uv_gt_np, "uv", "uv_gt.jpg"),
            (bw_np, "bw", "bw_pred.jpg"),
            (bw_gt_np, "bw", "bw_gt.jpg"),
            (bw_uv, "bw", "bw_uv_pred.jpg"),
            (mask_np, "background", "bg_gt.jpg"),
            (ori, "ori", "ori_gt.jpg"),
        )
        for batch_array, kind, basename in first_sample_outputs:
            print_img_auto(batch_array[0, :, :, :],
                           kind,
                           is_gt=False,
                           fname=batch_file(basename))

        # ------------  Write bw diffs --------------------
        # Differences of the first sample only; these arrays are 3-D.
        diff1 = bw_gt_np[0, :, :, :] - bw_np[0, :, :, :]
        diff2 = bw_gt_np[0, :, :, :] - bw_uv[0, :, :, :]
        max_both = np.max([np.max(diff1), np.max(diff2)])
        min_both = np.min([np.min(diff1), np.min(diff2)])
        # Shared range so both difference images use the same scale; guard
        # against a constant difference, which would divide by zero.
        span = max_both - min_both
        if span == 0:
            span = 1
        diff_p1 = (diff1 - min_both) / span * 255
        diff_p2 = (diff2 - min_both) / span * 255

        std1 = np.std(diff1)
        std2 = np.std(diff2)
        diff_m1, mean1_1, mean1_2 = _center_first_two_channels(diff1)
        diff_m2, mean2_1, mean2_2 = _center_first_two_channels(diff2)
        m1_1 += np.abs(mean1_1)
        m1_2 += np.abs(mean1_2)
        m2_1 += np.abs(mean2_1)
        m2_2 += np.abs(mean2_2)
        s1 += std1
        s2 += std2

        diff_scalars = (
            ("bw_all_single/mean1_1", mean1_1),
            ("bw_all_single/mean1_2", mean1_2),
            ("bw_all_single/mean2_1", mean2_1),
            ("bw_all_single/mean2_2", mean2_2),
            ("bw_all_single/std_1", std1),
            ("bw_all_single/std_2", std2),
            ("bw_all_total/m1_1", m1_1),
            ("bw_all_total/m1_2", m1_2),
            ("bw_all_total/m2_1", m2_1),
            ("bw_all_total/m2_2", m2_2),
            ("bw_all_total/mean2", m2),
            ("bw_all_total/std_1", s1),
            ("bw_all_total/std_2", s2),
        )
        for tag, value in diff_scalars:
            writer.add_scalar(tag, value, global_step=batch_idx)

        diff_outputs = (
            (diff_p1, "diff_bw.jpg"),
            (diff_p2, "diff_bwuv.jpg"),
            (diff_m1, "diff_m_bw.jpg"),
            (diff_m2, "diff_m_bwuv.jpg"),
        )
        for image, basename in diff_outputs:
            print_img_auto(image, "bw", fname=batch_file(basename))

    return


def _log_uvbw_scalars(writer, group, values, step):
    """Write the four uv/bw loss scalars of one metric group to ``writer``."""
    names = ("uv_bw_loss", "bw_loss", "ori_uv_loss", "ori_bw_loss")
    for name, value in zip(names, values):
        writer.add_scalar("{}/{}".format(group, name),
                          value,
                          global_step=step)


def _center_first_two_channels(diff):
    """Return a float copy of 3-D ``diff`` with channels 0 and 1 mean-centred, plus both means."""
    mean_0 = np.average(diff[:, :, 0])
    mean_1 = np.average(diff[:, :, 1])
    # astype() copies, so the caller's array is left untouched, and the float
    # dtype keeps the in-place subtraction valid for integer inputs.
    centred = diff.astype(np.float64)
    centred[:, :, 0] -= mean_0
    centred[:, :, 1] -= mean_1
    return centred, mean_0, mean_1