Example #1
0
# Statistic names indexed in every ``metrics.film_metrics`` result; each per-map
# metric accumulator uses these keys (their order is also the report row order).
_METRIC_KEYS = ('l1_norm', 'mse', 'pearsonr_metric', 'cc', 'psnr', 'ssim', 'mssim')
# Maps scored when ``args.calculate_CC`` is set, in report column order.
_METRIC_MAPS = ('uv', 'dep', 'nor', 'cmap', 'alb', 'bw', 'deori', 'dealb')
# mean/std passed to ``re_normalize`` and ``write_image_tensor(..., 'gauss')``
# for the 3D-coordinate ("cmap") map and the normal map.
_CMAP_MEAN = [0.1108, 0.3160, 0.2859]
_CMAP_STD = [0.7065, 0.6840, 0.7141]
_NOR_MEAN = [0.5619, 0.2881, 0.2917]
_NOR_STD = [0.5619, 0.7108, 0.7083]
# At most this many samples of the selected batch are dumped as images.
_MAX_WRITE_IMAGES = 5


def _consistency_loss(criterion, pred, target, device):
    """Score an optional cross-task consistency output.

    Returns ``criterion(pred, target)`` cast to float, or a one-element zero
    tensor on ``device`` when the model did not produce ``pred`` (``None``).
    """
    if pred is None:
        return torch.Tensor([0.]).to(device)
    return criterion(pred, target).float()


def _film_metrics(pred, gt, rgb, device):
    """Compute ``metrics.film_metrics`` scores for every map of one batch.

    ``pred`` and ``gt`` map 'uv', 'cmap', 'nor', 'alb', 'dep' and 'mask' to
    batched tensors; ``gt['alb']`` must already be the single-channel albedo
    the prediction is compared against. Returns a dict keyed by the names in
    ``_METRIC_MAPS`` holding the raw ``metric_op`` outputs.
    """
    metric_op = metrics.film_metrics().to(device)

    def score(name, mean, std):
        # Prediction and target are re-normalized with the same statistics.
        return metric_op(re_normalize(pred[name], mean=mean, std=std, inplace=False),
                         re_normalize(gt[name], mean=mean, std=std, inplace=False))

    scores = {
        'alb': score('alb', 0.5, 0.5),
        'uv': score('uv', [0.5, 0.5], [0.5, 0.5]),
        'cmap': score('cmap', _CMAP_MEAN, _CMAP_STD),
        'nor': score('nor', _NOR_MEAN, _NOR_STD),
        'dep': score('dep', 0.5, 0.5),
    }
    # Backward maps from UV + mask; uv2bmap4d's result is a NumPy array.
    bw_pred = torch.from_numpy(metrics.uv2bmap4d(pred['uv'], pred['mask'])).to(device)
    bw_gt = torch.from_numpy(metrics.uv2bmap4d(gt['uv'], gt['mask'])).to(device)
    scores['bw'] = metric_op(bw_pred, bw_gt)
    # Dewarp the input image and the albedo with the predicted vs. GT maps.
    scores['deori'] = metric_op(metrics.bw_mapping4d(bw_pred, rgb, device),
                                metrics.bw_mapping4d(bw_gt, rgb, device))
    dewarp_ab = metrics.bw_mapping4d(bw_pred, pred['alb'], device)
    dewarp_ab_gt = metrics.bw_mapping4d(bw_gt, gt['alb'], device)
    scores['dealb'] = metric_op(torch.unsqueeze(dewarp_ab, 0), torch.unsqueeze(dewarp_ab_gt, 0))
    return scores


def _write_test_images(out_dir, device, num_images, rgb, pred, gt):
    """Dump predictions, ground truth and dewarped images of one batch.

    Writes ``<out_dir><stem>_ind_<k>.jpg`` for each sample ``k`` in
    ``range(num_images)``. ``out_dir`` is used as a plain string prefix (it
    should end with '/'); ``pred`` and ``gt`` map 'uv', 'mask', 'alb', 'cmap',
    'dep' and 'nor' to batched tensors indexed by sample on dim 0.
    """
    def path(stem, k):
        return '{}{}_ind_{}.jpg'.format(out_dir, stem, k)

    for k in range(num_images):
        uv_pred = pred['uv'][k]
        back_pred = torch.round(pred['mask'][k])
        ori_gt = rgb[k]
        uv_gt = gt['uv'][k]
        mask_gt = gt['mask'][k]
        ab_gt = gt['alb'][k]

        bw_gt = metrics.uv2bmap(uv_gt, mask_gt)
        bw_pred = metrics.uv2bmap(uv_pred, back_pred)  # [-1,1], [256, 256, 3]
        dewarp_ori = metrics.bw_mapping(bw_pred, ori_gt, device)
        dewarp_ab = metrics.bw_mapping(bw_pred, ab_gt, device)
        dewarp_ori_gt = metrics.bw_mapping(bw_gt, ori_gt, device)

        # Predictions.
        write_image_tensor(uv_pred, path('pred_uv', k), 'std', device=device)
        write_image_tensor(back_pred, path('pred_back', k), '01')
        write_image_tensor(pred['alb'][k], path('pred_ab', k), 'std')
        write_image_tensor(pred['cmap'][k], path('pred_3D', k), 'gauss', mean=_CMAP_MEAN, std=_CMAP_STD)
        write_image_tensor(pred['dep'][k], path('pred_depth', k), 'gauss', mean=[0.5], std=[0.5])
        write_image_tensor(pred['nor'][k], path('pred_normal', k), 'gauss', mean=_NOR_MEAN, std=_NOR_STD)
        write_image_np(bw_pred, path('pred_bw', k))
        # Ground truth.
        write_image_tensor(ori_gt, path('gt_ori', k), 'std')
        write_image_tensor(uv_gt, path('gt_uv', k), 'std', device=device)
        write_image_tensor(mask_gt, path('gt_back', k), '01')
        write_image_tensor(ab_gt, path('gt_ab', k), 'std')
        write_image_tensor(gt['cmap'][k], path('gt_cmap', k), 'gauss', mean=_CMAP_MEAN, std=_CMAP_STD)
        write_image_tensor(gt['dep'][k], path('gt_depth', k), 'gauss', mean=[0.5], std=[0.5])
        write_image_tensor(gt['nor'][k], path('gt_normal', k), 'gauss', mean=_NOR_MEAN, std=_NOR_STD)
        write_image_np(bw_gt, path('gt_bw', k))
        write_image(dewarp_ori_gt, path('gt_dewarpOri', k))
        # Dewarped with the predicted backward map.
        write_image(dewarp_ori, path('dewarp_ori', k))
        write_image(dewarp_ab, path('dewarp_ab', k))


def _format_metric_line(key, metric_avgs):
    """Return one tab-separated report row ``<key>_<map>:<value>`` over ``_METRIC_MAPS``."""
    return '\t'.join('{}_{}:{:.6f}'.format(key, name, metric_avgs[name][key]) for name in _METRIC_MAPS)


def test(args, model, device, test_loader, criterion, epoch, writer, output_dir, isWriteImage, sstep):
    """Evaluate ``model`` on ``test_loader`` and report the averaged test losses.

    Every batch is run without gradients; the weighted total loss and each
    individual term are averaged over the number of batches actually iterated.
    Optionally:

    * ``args.calculate_CC`` -- scores every map with ``metrics.film_metrics``
      and prints the per-map averages;
    * ``isWriteImage`` -- dumps up to ``_MAX_WRITE_IMAGES`` samples of one batch
      under ``output_dir + 'test/epoch_<epoch>_batch_<idx>/'``;
    * ``args.write_txt`` -- appends the averages to ``output_txt/<model_name>.txt``;
    * ``args.write_summary`` -- logs the averaged losses to ``writer`` at ``sstep``.

    Positions 0-6 of each batch ``data`` are read as rgb, albedo, depth, normal,
    3D map, UV map and mask; ``model(rgb)`` must return the 10-tuple unpacked below.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    print('Testing')
    model.eval()
    # Order matches the console report, the txt report and the summary tags.
    loss_names = ('total', 'alb', 'threeD', 'uv', 'nor', 'dep', 'mask', 't2n', 't2d', 'n2d', 'd2n', 'tv')
    loss_sums = dict.fromkeys(loss_names, 0)
    metric_sums = {name: dict.fromkeys(_METRIC_KEYS, 0) for name in _METRIC_MAPS}
    num_batches = 0
    # Images are dumped for the last full batch (the first batch if the
    # dataset is smaller than one batch, where the old index was -1).
    write_batch_idx = max(len(test_loader.dataset) // args.test_batch_size, 1) - 1
    with torch.no_grad():
        tv_critic = TVLoss().to(device)
        bce_critic = nn.BCELoss()
        start_time = time.time()
        for batch_idx, data in enumerate(test_loader):
            num_batches += 1
            rgb, alb_map_gt, dep_map_gt, nor_map_gt, threeD_map_gt, uv_map_gt, mask_map_gt = (
                t.to(device) for t in data[:7])
            (uv_map, threeD_map, nor_map, alb_map, dep_map, mask_map,
             nor_from_threeD, dep_from_threeD, nor_from_dep, dep_from_nor) = model(rgb)
            # The albedo prediction is scored against channel 0 of the target only.
            # Kept separate so alb_map_gt itself is never rebound.
            alb_gt_1ch = torch.unsqueeze(alb_map_gt[:, 0, :, :], 1)

            losses = {
                'mask': bce_critic(mask_map, mask_map_gt).float(),
                'tv': tv_critic(mask_map).float(),
                'threeD': criterion(threeD_map, threeD_map_gt).float(),
                'uv': criterion(uv_map, uv_map_gt).float(),
                'dep': criterion(dep_map, dep_map_gt).float(),
                'nor': criterion(nor_map, nor_map_gt).float(),
                'alb': criterion(alb_map, alb_gt_1ch).float(),
                't2n': _consistency_loss(criterion, nor_from_threeD, nor_map_gt, device),
                't2d': _consistency_loss(criterion, dep_from_threeD, dep_map_gt, device),
                'd2n': _consistency_loss(criterion, nor_from_dep, nor_map_gt, device),
                'n2d': _consistency_loss(criterion, dep_from_nor, dep_map_gt, device),
            }
            losses['total'] = (4 * losses['uv'] + 4 * losses['alb'] + 4 * losses['nor'] + losses['dep']
                               + 2 * losses['mask'] + losses['threeD'] + losses['t2n'] + losses['t2d']
                               + losses['d2n'] + losses['n2d'] + losses['tv'])
            for name in loss_names:
                loss_sums[name] = loss_sums[name] + losses[name]

            pred = {'uv': uv_map, 'cmap': threeD_map, 'nor': nor_map,
                    'alb': alb_map, 'dep': dep_map, 'mask': mask_map}
            gt = {'uv': uv_map_gt, 'cmap': threeD_map_gt, 'nor': nor_map_gt,
                  'alb': alb_map_gt, 'dep': dep_map_gt, 'mask': mask_map_gt}
            if args.calculate_CC:
                scores = _film_metrics(pred, dict(gt, alb=alb_gt_1ch), rgb, device)
                metric_sums = {name: {key: sums[key] + scores[name][key] for key in sums}
                               for name, sums in metric_sums.items()}
            if isWriteImage and batch_idx == write_batch_idx:
                out_dir = output_dir + 'test/epoch_{}_batch_{}'.format(epoch, batch_idx)
                os.makedirs(out_dir, exist_ok=True)
                print('writing test image')
                _write_test_images(out_dir + '/', device, min(_MAX_WRITE_IMAGES, rgb.size(0)),
                                   rgb, pred, gt)
            if (batch_idx + 1) % 20 == 0:
                print('It cost {} seconds to test {} images'.format(time.time() - start_time,
                                                                    (batch_idx + 1) * args.test_batch_size))
                start_time = time.time()

    if num_batches == 0:
        raise ValueError('test_loader produced no batches')
    loss_values = [(loss_sums[name] / num_batches).item() for name in loss_names]
    expected_batches = math.ceil(len(test_loader.dataset) / args.test_batch_size)
    print('Epoch:{} \n batch index:{}/{}||loss:{:.6f}||alb:{:.4f}||threeD:{:.4f}||uv:{:.6f}||nor:{:.4f}'
          '||dep:{:.4f}||mask:{:.6f}||cons_t2n:{:.6f}||cons_t2d:{:.6f}||cons_n2d:{:.6f}||cons_d2n:{:.6f}'
          '||loss_tv:{:.6f}'.format(epoch, num_batches, expected_batches, *loss_values))

    metric_avgs = {name: {key: total / num_batches for key, total in sums.items()}
                   for name, sums in metric_sums.items()}
    if args.calculate_CC:
        for key in _METRIC_KEYS:
            print(_format_metric_line(key, metric_avgs))

    if args.write_txt:
        txt_path = 'output_txt/' + args.model_name + '.txt'
        os.makedirs(os.path.dirname(txt_path), exist_ok=True)
        with open(txt_path, 'a') as f:
            f.write('Epoch: {} \t Test Loss: {:.6f}, \t ab: {:.4f}, \t cmap: {:.4f}, \t uv: {:.6f}, '
                    '\t normal: {:.4f}, \t depth: {:.4f}, \t back: {:.6f} , '
                    '\t constrain 3d to normal: {:.4f}, \t constrain 3d to depth: {:.4f}, '
                    '\t constrain normal to depth: {:.4f}, \t constrain depth to normal: {:.4f}, '
                    '\t loss tv: {:.6f}\n'.format(epoch, *loss_values))
            # Metric rows exist only when they were actually computed.
            if args.calculate_CC:
                for key in _METRIC_KEYS:
                    f.write(_format_metric_line(key, metric_avgs) + '\n')

    if args.write_summary:
        print('sstep', sstep)
        # Tag names (including the odd 'test_loss_dep2nor') are kept unchanged
        # so existing TensorBoard curves continue.
        summary_tags = ('summary/test_loss', 'summary/test_loss_ab', 'summary/test_loss_cmap',
                        'summary/test_loss_uv', 'summary/test_loss_normal', 'summary/test_loss_depth',
                        'summary/test_loss_back', 'summary/test_con_3d2nor', 'summary/test_con_3d2dep',
                        'summary/test_con_nor2dep', 'summary/test_loss_dep2nor', 'summary/test_loss_tv')
        for tag, value in zip(summary_tags, loss_values):
            writer.add_scalar(tag, value, global_step=sstep)
Example #2
0
def test(args, model, device, test_loader, criterion, epoch, writer,
         output_dir, isWriteImage, sstep):
    print('Testing')
    model.eval()
    test_loss = 0
    correct = 0
    cc_uv = 0
    cc_threeD = 0
    cc_alb = 0
    cc_bw = 0
    cc_dewarp_ori = 0
    cc_dewarp_alb = 0
    with torch.no_grad():
        loss_sum = 0
        loss_sum_alb = 0
        loss_sum_threeD = 0
        loss_sum_uv = 0
        loss_sum_nor = 0
        loss_sum_dep = 0
        loss_sum_mask = 0
        cons_sum_t2n = 0
        cons_sum_t2d = 0
        cons_sum_n2d = 0
        cons_sum_d2n = 0
        start_time = time.time()
        for batch_idx, data in enumerate(test_loader):
            rgb = data[0]
            alb_map_gt = data[1]
            dep_map_gt = data[2]
            nor_map_gt = data[3]
            threeD_map_gt = data[4]
            uv_map_gt = data[5]
            mask_map_gt = data[6]
            rgb, alb_map_gt, dep_map_gt, nor_map_gt, uv_map_gt, threeD_map_gt, mask_map_gt = rgb.to(
                device), alb_map_gt.to(device), dep_map_gt.to(
                    device), nor_map_gt.to(device), uv_map_gt.to(
                        device), threeD_map_gt.to(device), mask_map_gt.to(
                            device)
            uv_map, threeD_map, nor_map, alb_map, dep_map, mask_map, nor_from_threeD, dep_from_threeD, nor_from_dep, dep_from_nor = model(
                rgb)
            loss_mask = criterion(mask_map, mask_map_gt).float()
            loss_threeD = criterion(threeD_map, threeD_map_gt).float()
            loss_uv = criterion(uv_map, uv_map_gt).float()
            loss_dep = criterion(dep_map, dep_map_gt).float()
            loss_nor = criterion(nor_map, nor_map_gt).float()
            loss_alb = criterion(alb_map,
                                 torch.unsqueeze(alb_map_gt[:, 0, :, :],
                                                 1)).float()
            if nor_from_threeD is not None:
                cons_threeD2nor = criterion(nor_from_threeD,
                                            nor_map_gt).float()
            else:
                cons_threeD2nor = torch.Tensor([0.]).to(device)
            if dep_from_threeD is not None:
                cons_threeD2dep = criterion(dep_from_threeD,
                                            dep_map_gt).float()
            else:
                cons_threeD2dep = torch.Tensor([0.]).to(device)
            if nor_from_dep is not None:
                cons_dep2nor = criterion(nor_from_dep, nor_map_gt).float()
            else:
                cons_dep2nor = torch.Tensor([0.]).to(device)
            if dep_from_nor is not None:
                cons_nor2dep = criterion(dep_from_nor, dep_map_gt).float()
            else:
                cons_nor2dep = torch.Tensor([0.]).to(device)
            test_loss = 4 * loss_uv + 4 * loss_alb + loss_nor + loss_dep + 2 * loss_mask + loss_threeD + \
                   cons_threeD2nor + cons_threeD2dep + cons_dep2nor + cons_nor2dep
            loss_sum = loss_sum + test_loss
            loss_sum_alb += loss_alb
            loss_sum_threeD += loss_threeD
            loss_sum_uv += loss_uv
            loss_sum_nor += loss_nor
            loss_sum_dep += loss_dep
            loss_sum_mask += loss_mask
            cons_sum_t2n += cons_threeD2nor
            cons_sum_t2d += cons_threeD2dep
            cons_sum_n2d += cons_nor2dep
            cons_sum_d2n += cons_dep2nor
            if args.calculate_CC:
                c_alb = metrics.calculate_CC_metrics(
                    alb_map, torch.unsqueeze(alb_map_gt[:, 0, :, :], 1))
                c_uv = metrics.calculate_CC_metrics(uv_map, uv_map_gt)
                c_threeD = metrics.calculate_CC_metrics(
                    threeD_map, threeD_map_gt)
                bw_pred = metrics.uv2bmap4d(uv_map, mask_map)
                bw_gt = metrics.uv2bmap4d(uv_map_gt, mask_map_gt)
                c_bw = metrics.calculate_CC_metrics(bw_pred, bw_gt)

                dewarp_ori = metrics.bw_mapping4d(bw_pred, rgb, device)
                dewarp_ori_gt = metrics.bw_mapping4d(bw_gt, rgb, device)
                c_dewarp_ori = metrics.calculate_CC_metrics(
                    dewarp_ori, dewarp_ori_gt)
                dewarp_ab = metrics.bw_mapping4d(bw_pred, alb_map, device)
                dewarp_ab_gt = metrics.bw_mapping4d(
                    bw_gt, torch.unsqueeze(alb_map_gt[:, 0, :, :], 1), device)
                c_dewarp_alb = metrics.calculate_CC_metrics(
                    torch.unsqueeze(dewarp_ab, 0),
                    torch.unsqueeze(dewarp_ab_gt, 0))

                cc_alb += c_alb
                cc_uv += c_uv
                cc_threeD += c_threeD
                cc_bw += c_bw
                cc_dewarp_ori += c_dewarp_ori
                cc_dewarp_alb += c_dewarp_alb
            if isWriteImage:
                if batch_idx == (len(test_loader.dataset) //
                                 args.test_batch_size) - 1:
                    if not os.path.exists(
                            output_dir +
                            'test/epoch_{}_batch_{}'.format(epoch, batch_idx)):
                        os.makedirs(
                            output_dir +
                            'test/epoch_{}_batch_{}'.format(epoch, batch_idx))
                    print('writing test image')
                    for k in range(args.test_batch_size):
                        # print('k', k)
                        albedo_pred = alb_map[k, :, :, :]
                        uv_pred = uv_map[k, :, :, :]
                        back_pred = mask_map[k, :, :, :]
                        cmap_pred = threeD_map[k, :, :, :]
                        depth_pred = dep_map[k, :, :, :]
                        normal_pred = nor_map[k, :, :, :]

                        ori_gt = rgb[k, :, :, :]
                        ab_gt = alb_map_gt[k, :, :, :]
                        uv_gt = uv_map_gt[k, :, :, :]
                        mask_gt = mask_map_gt[k, :, :, :]
                        cmap_gt = threeD_map_gt[k, :, :, :]
                        depth_gt = dep_map_gt[k, :, :, :]
                        normal_gt = nor_map_gt[k, :, :, :]

                        bw_gt = metrics.uv2bmap(uv_gt, mask_gt)
                        bw_pred = metrics.uv2bmap(
                            uv_pred, back_pred)  # [-1,1], [256, 256, 3]

                        # bw_gt = bmap[k, :, :, :]

                        dewarp_ori = metrics.bw_mapping(
                            bw_pred, ori_gt, device)
                        dewarp_ab = metrics.bw_mapping(bw_pred, ab_gt, device)
                        dewarp_ori_gt = metrics.bw_mapping(
                            bw_gt, ori_gt, device)

                        output_dir1 = output_dir + 'test/epoch_{}_batch_{}/'.format(
                            epoch, batch_idx)
                        output_uv_pred = output_dir1 + 'pred_uv_ind_{}'.format(
                            k) + '.jpg'
                        output_back_pred = output_dir1 + 'pred_back_ind_{}'.format(
                            k) + '.jpg'
                        output_ab_pred = output_dir1 + 'pred_ab_ind_{}'.format(
                            k) + '.jpg'
                        output_3d_pred = output_dir1 + 'pred_3D_ind_{}'.format(
                            k) + '.jpg'
                        output_bw_pred = output_dir1 + 'pred_bw_ind_{}'.format(
                            k) + '.jpg'
                        output_depth_pred = output_dir1 + 'pred_depth_ind_{}'.format(
                            k) + '.jpg'
                        output_normal_pred = output_dir1 + 'pred_normal_ind_{}'.format(
                            k) + '.jpg'

                        output_ori = output_dir1 + 'gt_ori_ind_{}'.format(
                            k) + '.jpg'
                        output_uv_gt = output_dir1 + 'gt_uv_ind_{}'.format(
                            k) + '.jpg'
                        output_ab_gt = output_dir1 + 'gt_ab_ind_{}'.format(
                            k) + '.jpg'
                        output_cmap_gt = output_dir1 + 'gt_cmap_ind_{}'.format(
                            k) + '.jpg'
                        output_back_gt = output_dir1 + 'gt_back_ind_{}'.format(
                            k) + '.jpg'
                        output_bw_gt = output_dir1 + 'gt_bw_ind_{}'.format(
                            k) + '.jpg'
                        output_dewarp_ori_gt = output_dir1 + 'gt_dewarpOri_ind_{}'.format(
                            k) + '.jpg'
                        output_depth_gt = output_dir1 + 'gt_depth_ind_{}'.format(
                            k) + '.jpg'
                        output_normal_gt = output_dir1 + 'gt_normal_ind_{}'.format(
                            k) + '.jpg'

                        output_dewarp_ori = output_dir1 + 'dewarp_ori_ind_{}'.format(
                            k) + '.jpg'
                        output_dewarp_ab = output_dir1 + 'dewarp_ab_ind_{}'.format(
                            k) + '.jpg'
                        """pred"""
                        write_image_tensor(uv_pred,
                                           output_uv_pred,
                                           'std',
                                           device=device)
                        write_image_tensor(back_pred, output_back_pred, '01')
                        write_image_tensor(albedo_pred, output_ab_pred, 'std')
                        write_image_tensor(cmap_pred,
                                           output_3d_pred,
                                           'gauss',
                                           mean=[0.100, 0.326, 0.289],
                                           std=[0.096, 0.332, 0.298])
                        write_image_tensor(depth_pred,
                                           output_depth_pred,
                                           'gauss',
                                           mean=[0.316],
                                           std=[0.309])
                        write_image_tensor(normal_pred,
                                           output_normal_pred,
                                           'gauss',
                                           mean=[0.584, 0.294, 0.300],
                                           std=[0.483, 0.251, 0.256])
                        write_image_np(bw_pred, output_bw_pred)
                        """gt"""
                        write_image_tensor(ori_gt, output_ori, 'std')
                        write_image_tensor(uv_gt,
                                           output_uv_gt,
                                           'std',
                                           device=device)
                        write_image_tensor(mask_gt, output_back_gt, '01')
                        write_image_tensor(ab_gt, output_ab_gt, 'std')
                        write_image_tensor(cmap_gt,
                                           output_cmap_gt,
                                           'gauss',
                                           mean=[0.100, 0.326, 0.289],
                                           std=[0.096, 0.332, 0.298])
                        write_image_tensor(depth_gt,
                                           output_depth_gt,
                                           'gauss',
                                           mean=[0.316],
                                           std=[0.309])
                        write_image_tensor(normal_gt,
                                           output_normal_gt,
                                           'gauss',
                                           mean=[0.584, 0.294, 0.300],
                                           std=[0.483, 0.251, 0.256])

                        write_image_np(bw_gt, output_bw_gt)

                        write_image(dewarp_ori_gt, output_dewarp_ori_gt)
                        """dewarp"""
                        write_image(dewarp_ori, output_dewarp_ori)
                        write_image(dewarp_ab, output_dewarp_ab)
            if (batch_idx + 1) % 20 == 0:
                print('It cost {} seconds to test {} images'.format(
                    time.time() - start_time,
                    (batch_idx + 1) * args.test_batch_size))
                start_time = time.time()
    test_loss = loss_sum / (len(test_loader.dataset) / args.test_batch_size)
    test_loss_alb = loss_sum_alb / (len(test_loader.dataset) /
                                    args.test_batch_size)
    test_loss_threeD = loss_sum_threeD / (len(test_loader.dataset) /
                                          args.test_batch_size)
    test_loss_uv = loss_sum_uv / (len(test_loader.dataset) /
                                  args.test_batch_size)
    test_loss_nor = loss_sum_nor / (len(test_loader.dataset) /
                                    args.test_batch_size)
    test_loss_dep = loss_sum_dep / (len(test_loader.dataset) /
                                    args.test_batch_size)
    test_loss_mask = loss_sum_mask / (len(test_loader.dataset) /
                                      args.test_batch_size)
    test_cons_t2n = cons_sum_t2n / (len(test_loader.dataset) /
                                    args.test_batch_size)
    test_cons_t2d = cons_sum_t2d / (len(test_loader.dataset) /
                                    args.test_batch_size)
    test_cons_n2d = cons_sum_n2d / (len(test_loader.dataset) /
                                    args.test_batch_size)
    test_cons_d2n = cons_sum_d2n / (len(test_loader.dataset) /
                                    args.test_batch_size)
    print(
        'Epoch:{} \n batch index:{}/{}||loss:{:.6f}||alb:{:.4f}||threeD:{:.4f}||uv:{:.6f}||nor:{:.4f}||dep:{:.4f}||mask:{:.6f}||cons_t2n:{:6f}'
        'cons_t2d:{:.6f}||cons_n2d:{:.6f}||cons_d2n:{:.6f}'.format(
            epoch, batch_idx + 1,
            len(test_loader.dataset) // args.batch_size, test_loss.item(),
            test_loss_alb.item(), test_loss_threeD.item(), test_loss_uv.item(),
            test_loss_nor.item(), test_loss_dep.item(), test_loss_mask.item(),
            test_cons_t2n.item(), test_cons_t2d.item(), test_cons_n2d.item(),
            test_cons_d2n.item()))
    if args.calculate_CC:
        cc_uv = cc_uv / (len(test_loader.dataset) / args.test_batch_size)
        cc_cmap = cc_threeD / (len(test_loader.dataset) / args.test_batch_size)
        cc_ab = cc_alb / (len(test_loader.dataset) / args.test_batch_size)
        cc_bw = cc_bw / (len(test_loader.dataset) / args.test_batch_size)
        cc_dewarp_ori = cc_dewarp_ori / (len(test_loader.dataset) /
                                         args.test_batch_size)
        cc_dewarp_ab = cc_dewarp_alb / (len(test_loader.dataset) /
                                        args.test_batch_size)
    if args.calculate_CC:
        print(
            'CC_uv: {}\t CC_cmap: {}\t CC_ab: {}\t CC_bw: {}\t CC_dewarp_ori: {}\t CC_dewarp_ab: {}'
            .format(cc_uv, cc_cmap, cc_ab, cc_bw, cc_dewarp_ori, cc_dewarp_ab))
    if args.write_txt:
        txt_dir = 'output_txt/' + args.model_name + '.txt'
        f = open(txt_dir, 'a')
        f.write(
            'Epoch: {} \t Test Loss: {:.6f}, \t ab: {:.4f}, \t cmap: {:.4f}, \t uv: {:.6f}, \t normal: {:.4f}, \t depth: {:.4f}, \t back: {:.6f} , \t constrain 3d to normal: {:.4f}, \t constrain 3d to depth: {:.4f}, \t constrain normal to depth: {:.4f}, \t constrain depth to normal: {:.4f}, CC_uv: {}\t CC_cmap: {}\t CC_ab: {}\t CC_bw: {}\t CC_dewarp_ori: {}\t CC_dewarp_ab: {}\n'
            .format(epoch, test_loss.item(), test_loss_alb.item(),
                    test_loss_threeD.item(), test_loss_uv.item(),
                    test_loss_nor.item(), test_loss_dep.item(),
                    test_loss_mask.item(), test_cons_t2n.item(),
                    test_cons_t2d.item(), test_cons_n2d.item(),
                    test_cons_d2n.item(), cc_uv, cc_cmap, cc_ab, cc_bw,
                    cc_dewarp_ori, cc_dewarp_ab))
        f.close()
    if args.write_summary:
        print('sstep', sstep)
        # writer.add_scalar('test_acc', 100. * correct / len(test_loader.dataset), global_step=epoch+1)
        writer.add_scalar('summary/test_loss',
                          test_loss.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_loss_ab',
                          test_loss_alb.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_loss_cmap',
                          test_loss_threeD.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_loss_uv',
                          test_loss_uv.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_loss_normal',
                          test_loss_nor.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_loss_depth',
                          test_loss_dep.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_loss_back',
                          test_loss_mask.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_con_3d2nor',
                          test_cons_t2n.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_con_3d2dep',
                          test_cons_t2d.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_con_nor2dep',
                          test_cons_n2d.item(),
                          global_step=sstep)
        writer.add_scalar('summary/test_loss_dep2nor',
                          test_cons_d2n.item(),
                          global_step=sstep)
Exemple #3
0
def test(args, model, device, test_loader, criterion, epoch, writer, output_dir_test, isWriteImage, sstep):
    """Evaluate ``model`` on ``test_loader`` and report losses and image metrics.

    The model is put in eval mode and run under ``torch.no_grad()``. For every
    batch this computes a BCE loss and a dice score on the mask output, and
    ``criterion`` losses on the cmap, depth, normal, albedo (first GT channel
    only) and deform outputs, and accumulates them. The weighted total is
    ``4*alb + 4*nor + dep + 2*mask + cmap + smooth``.

    When ``args.calculate_CC`` is set, it also accumulates the metric dicts
    returned by ``metrics.film_metrics_ncc`` on re-normalized, clamped
    predictions and ground truth. When ``isWriteImage`` is set, it writes the
    predictions and ground truth of the batch with index
    ``len(dataset) // args.test_batch_size - 2`` to disk.

    Sums are averaged over the number of batches actually processed and then:
    printed; appended to ``output_txt/<args.model_name>.txt`` if
    ``args.write_txt``; and logged with ``writer.add_scalar`` at
    ``global_step=sstep`` if ``args.write_summary``.

    Args:
        args: Options object; reads ``test_batch_size``, ``calculate_CC``,
            ``write_txt``, ``model_name`` and ``write_summary``.
        model: Network whose forward output is unpacked as
            ``(cmap, nor_map, alb_map, dep_map, mask_map, deform_map)``.
        device: Device the batch tensors are moved to.
        test_loader: Iterable of batches indexed as
            ``(rgb, albedo, depth, normal, cmap, uv, mask, bw, deform, names)``.
            It must expose ``.dataset`` supporting ``len()``.
        criterion: Loss callable invoked as ``criterion(pred, target)``.
        epoch: Epoch number, used in log text and image output paths.
        writer: Object exposing ``add_scalar`` (presumably a TensorBoard
            ``SummaryWriter``). Only used when ``args.write_summary``.
        output_dir_test: Root directory under which test images are written.
        isWriteImage: Whether to dump images for one test batch.
        sstep: Global step passed to ``writer.add_scalar``.

    Raises:
        ValueError: If ``test_loader`` yields no batches.
    """
    print('Testing')
    model.eval()
    metrics_cmap = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0}
    metrics_alb = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0}
    metrics_dep = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0}
    metrics_nor = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0}
    metrics_deform_bw = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0}
    metrics_deform_ori = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0}
    metrics_deform_alb = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0}
    with torch.no_grad():
        loss_sum = 0
        loss_sum_alb = 0
        loss_sum_cmap = 0
        loss_sum_nor = 0
        loss_sum_dep = 0
        loss_sum_mask = 0
        loss_sum_smooth = 0
        dice_sum_metric = 0
        bc_critic = nn.BCELoss()
        # Stays -1 when the loader is empty, so the batch count below is 0.
        batch_idx = -1
        start_time = time.time()
        for batch_idx, data in enumerate(test_loader):
            rgb = data[0].to(device)
            alb_map_gt = data[1].to(device)
            dep_map_gt = data[2].to(device)
            nor_map_gt = data[3].to(device)
            cmap_map_gt = data[4].to(device)
            uv_map_gt = data[5].to(device)
            mask_map_gt = data[6].to(device)
            bw_map_gt = data[7].to(device)
            deform_map_gt = data[8].to(device)
            names = data[9]
            cmap, nor_map, alb_map, dep_map, mask_map, deform_map = model(rgb)
            ######################################################################################################################
            #                                                  TEST LOSS SUMMARY
            ######################################################################################################################
            loss_mask = bc_critic(mask_map, mask_map_gt).float()
            dice_metric = diceCoeff(mask_map, mask_map_gt).float()
            loss_cmap = criterion(cmap, cmap_map_gt).float()
            loss_dep = criterion(dep_map, dep_map_gt).float()
            loss_nor = criterion(nor_map, nor_map_gt).float()
            # Albedo is supervised on the first GT channel only; keep that single-channel
            # (N, 1, H, W) view for the loss, the metrics and the written images below.
            alb_map_gt = torch.unsqueeze(alb_map_gt[:, 0, :, :], 1)
            loss_alb = criterion(alb_map, alb_map_gt).float()
            loss_smooth = criterion(deform_map, deform_map_gt)
            test_loss = 4 * loss_alb + 4 * loss_nor + loss_dep + 2 * loss_mask + loss_cmap + \
                        loss_smooth
            loss_sum = loss_sum + test_loss
            loss_sum_alb += loss_alb
            loss_sum_cmap += loss_cmap
            loss_sum_nor += loss_nor
            loss_sum_dep += loss_dep
            loss_sum_mask += loss_mask
            loss_sum_smooth += loss_smooth
            dice_sum_metric += dice_metric
            if args.calculate_CC:
                metric_op = metrics.film_metrics_ncc().to(device)
                # ALB
                alb_map_gt_recover = torch.clamp(re_normalize(alb_map_gt, mean=0.5, std=0.5, inplace=False), 0., 1.)
                alb_map_recover = torch.clamp(re_normalize(alb_map, mean=0.5, std=0.5, inplace=False), 0., 1.)
                metric_alb = metric_op(alb_map_recover, alb_map_gt_recover)
                # CMAP
                cmap_recover = torch.clamp(
                    re_normalize(cmap, mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141], inplace=False), 0.,
                    1.)
                cmap_gt_recover = torch.clamp(
                    re_normalize(cmap_map_gt, mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141],
                                 inplace=False),
                    0., 1.)
                metric_cmap = metric_op(cmap_recover, cmap_gt_recover)
                # Normal
                nor_map_recover = torch.clamp(
                    re_normalize(nor_map, mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083], inplace=False),
                    0., 1.)
                nor_map_gt_recover = torch.clamp(
                    re_normalize(nor_map_gt, mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083],
                                 inplace=False), 0., 1.)
                metric_nor = metric_op(nor_map_recover, nor_map_gt_recover)
                # Depth
                dep_map_recover = torch.clamp(re_normalize(dep_map, mean=0.5, std=0.5, inplace=False), 0., 1.)
                dep_map_gt_recover = torch.clamp(re_normalize(dep_map_gt, mean=0.5, std=0.5, inplace=False), 0., 1.)
                metric_dep = metric_op(dep_map_recover, dep_map_gt_recover)
                # UV2BW: the GT backward map is derived from the UV map and mask;
                # the (0, 3, 1, 2) transpose moves the last axis of the numpy array to axis 1.
                bw_gt = metrics.uv2bmap4d(uv_map_gt, mask_map_gt)
                bw_gt_recover = torch.clamp(
                    re_normalize(torch.from_numpy(np.transpose(bw_gt, (0, 3, 1, 2))).to(device), mean=[0.5, 0.5],
                                 std=[0.5, 0.5]), 0., 1.)
                # DEFORM
                deform_bw_pred = deform2bw_tensor_batch(deform_map.clone(), device)
                deform_bw_pred_recover = torch.clamp(
                    re_normalize(deform_bw_pred, mean=[0.5, 0.5], std=[0.5, 0.5], inplace=False), 0., 1.)
                metric_deform_bw = metric_op(deform_bw_pred_recover.float().contiguous(),
                                             bw_gt_recover.float().contiguous())
                # UV2BW Dewarp ORI
                dewarp_ori_gt = metrics.bw_mapping4d(bw_gt, rgb, device)
                dewarp_ori_gt_recover = torch.clamp(
                    re_normalize(dewarp_ori_gt.transpose(2, 3).transpose(1, 2), mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5], inplace=False), 0., 1.)
                # UV2BW Dewarp ALB
                dewarp_alb_gt = metrics.bw_mapping4d(bw_gt, alb_map_gt, device)
                dewarp_alb_gt_recover = torch.clamp(re_normalize(dewarp_alb_gt, mean=[0.5], std=[0.5], inplace=False),
                                                    0., 1.)
                # DEFORM DEWARP ORI
                deform_dewarp_ori = metrics.bw_mapping4d(deform_bw_pred, rgb, device)
                deform_dewarp_ori_recover = torch.clamp(
                    re_normalize(deform_dewarp_ori.transpose(2, 3).transpose(1, 2), mean=[0.5, 0.5, 0.5],
                                 std=[0.5, 0.5, 0.5],
                                 inplace=False), 0., 1.)
                metric_deform_ori = metric_op(deform_dewarp_ori_recover.contiguous(),
                                              dewarp_ori_gt_recover.contiguous())
                # DEFORM DEWARP ALB
                deform_dewarp_alb = metrics.bw_mapping4d(deform_bw_pred, alb_map, device)
                deform_dewarp_alb_recover = torch.clamp(
                    re_normalize(deform_dewarp_alb, mean=[0.5], std=[0.5], inplace=False), 0., 1.)
                metric_deform_alb = metric_op(torch.unsqueeze(deform_dewarp_alb_recover, 1).contiguous(),
                                              torch.unsqueeze(dewarp_alb_gt_recover, 1).contiguous())
                # SUMMARY: add this batch's metric values key by key.
                metrics_alb = {key: metrics_alb[key] + metric_alb[key] for key in metrics_alb}
                metrics_cmap = {key: metrics_cmap[key] + metric_cmap[key] for key in metrics_cmap}
                metrics_dep = {key: metrics_dep[key] + metric_dep[key] for key in metrics_dep}
                metrics_nor = {key: metrics_nor[key] + metric_nor[key] for key in metrics_nor}
                metrics_deform_bw = {key: metrics_deform_bw[key] + metric_deform_bw[key] for key in
                                     metrics_deform_bw}
                metrics_deform_ori = {key: metrics_deform_ori[key] + metric_deform_ori[key] for key in
                                      metrics_deform_ori}
                metrics_deform_alb = {key: metrics_deform_alb[key] + metric_deform_alb[key] for key in
                                      metrics_deform_alb}
            if isWriteImage and batch_idx == (len(test_loader.dataset) // args.test_batch_size) - 2:
                # Build the directory once with os.path.join so the directory that is
                # created is the same one the image files are written into.
                output_dir = os.path.join(output_dir_test, 'test/epoch_{}_batch_{}/'.format(epoch, batch_idx))
                os.makedirs(output_dir, exist_ok=True)
                print('writing test image')
                # Backward maps for the whole batch do not depend on k: compute them once.
                deform_bw_map = deform2bw_tensor_batch(deform_map.clone().to(device), device)
                for k in range(args.test_batch_size):
                    alb_pred = alb_map[k, :, :, :]
                    # Round the mask prediction (BCELoss above requires it to lie in [0, 1]).
                    mask_pred = torch.round(mask_map[k, :, :, :])
                    cmap_pred = cmap[k, :, :, :]
                    dep_pred = dep_map[k, :, :, :]
                    nor_pred = nor_map[k, :, :, :]
                    deform_bw_pred = deform_bw_map[k, :, :, :]
                    name = names[k]
                    ori_gt = rgb[k, :, :, :]
                    alb_gt = alb_map_gt[k, :, :, :]
                    uv_gt = uv_map_gt[k, :, :, :]
                    mask_gt = mask_map_gt[k, :, :, :]
                    cmap_gt = cmap_map_gt[k, :, :, :]
                    dep_gt = dep_map_gt[k, :, :, :]
                    nor_gt = nor_map_gt[k, :, :, :]

                    bw_gt = metrics.uv2bmap(uv_gt, mask_gt)
                    dewarp_ori_gt = metrics.bw_mapping(bw_gt, ori_gt, device)
                    dewarp_alb_gt = metrics.bw_mapping(bw_gt, alb_gt, device)
                    deform_dewarp_ori = metrics.bw_mapping(deform_bw_pred, ori_gt, device)
                    deform_dewarp_alb = metrics.bw_mapping(deform_bw_pred, alb_pred, device)
                    output_mask_pred = os.path.join(output_dir, 'pred_mask_ind_{}'.format(name) + '.jpg')
                    output_alb_pred = os.path.join(output_dir, 'pred_alb_ind_{}'.format(name) + '.jpg')
                    output_cmap_pred = os.path.join(output_dir, 'pred_3D_ind_{}'.format(name) + '.jpg')
                    output_deform_bw_pred = os.path.join(output_dir, 'pred_deform_bw_ind_{}'.format(name) + '.exr')
                    output_dep_pred = os.path.join(output_dir, 'pred_dep_ind_{}'.format(name) + '.jpg')
                    output_nor_pred = os.path.join(output_dir, 'pred_normal_ind_{}'.format(name) + '.jpg')
                    output_ori = os.path.join(output_dir, 'gt_ori_ind_{}'.format(name) + '.jpg')
                    output_uv_gt = os.path.join(output_dir, 'gt_uv_ind_{}'.format(name) + '.exr')
                    output_alb_gt = os.path.join(output_dir, 'gt_alb_ind_{}'.format(name) + '.jpg')
                    output_cmap_gt = os.path.join(output_dir, 'gt_cmap_ind_{}'.format(name) + '.jpg')
                    output_mask_gt = os.path.join(output_dir, 'gt_mask_ind_{}'.format(name) + '.jpg')
                    output_bw_gt = os.path.join(output_dir, 'gt_bw_ind_{}'.format(name) + '.exr')
                    output_dewarp_ori_gt = os.path.join(output_dir, 'gt_dewarpOri_ind_{}'.format(name) + '.jpg')
                    output_dewarp_alb_gt = os.path.join(output_dir, 'gt_dewarpAlb_ind_{}'.format(name) + '.jpg')
                    output_dep_gt = os.path.join(output_dir, 'gt_dep_ind_{}'.format(name) + '.jpg')
                    output_nor_gt = os.path.join(output_dir, 'gt_nor_ind_{}'.format(name) + '.jpg')
                    output_deform_dewarp_ori = os.path.join(output_dir,
                                                            'deform_dewarp_ori_ind_{}'.format(name) + '.jpg')
                    output_deform_dewarp_alb = os.path.join(output_dir,
                                                            'deform_dewarp_alb_ind_{}'.format(name) + '.jpg')

                    # pred
                    write_image_tensor(mask_pred, output_mask_pred, '01')
                    write_image_tensor(alb_pred, output_alb_pred, 'std')
                    write_image_tensor(cmap_pred, output_cmap_pred, 'gauss', mean=[0.1108, 0.3160, 0.2859],
                                       std=[0.7065, 0.6840, 0.7141])
                    write_image_tensor(dep_pred, output_dep_pred, 'gauss', mean=[0.5], std=[0.5])
                    write_image_tensor(nor_pred, output_nor_pred, 'gauss', mean=[0.5619, 0.2881, 0.2917],
                                       std=[0.5619, 0.7108, 0.7083])
                    write_image_tensor(deform_bw_pred, output_deform_bw_pred, 'std', device=device)
                    # gt
                    write_image_tensor(ori_gt, output_ori, 'std')
                    write_image_tensor(uv_gt, output_uv_gt, 'std', device=device)
                    write_image_tensor(mask_gt, output_mask_gt, '01')
                    write_image_tensor(alb_gt, output_alb_gt, 'std')
                    write_image_tensor(cmap_gt, output_cmap_gt, 'gauss', mean=[0.1108, 0.3160, 0.2859],
                                       std=[0.7065, 0.6840, 0.7141])
                    write_image_tensor(dep_gt, output_dep_gt, 'gauss', mean=[0.5], std=[0.5])
                    write_image_tensor(nor_gt, output_nor_gt, 'gauss', mean=[0.5619, 0.2881, 0.2917],
                                       std=[0.5619, 0.7108, 0.7083])
                    write_image_np(bw_gt, output_bw_gt)

                    write_image(dewarp_ori_gt, output_dewarp_ori_gt)
                    write_image(dewarp_alb_gt, output_dewarp_alb_gt)
                    # dewarp
                    write_image(deform_dewarp_ori, output_deform_dewarp_ori)
                    write_image(deform_dewarp_alb, output_deform_dewarp_alb)
            if (batch_idx + 1) % 20 == 0:
                print('It cost {} seconds to test {} images'.format(time.time() - start_time,
                                                                    (batch_idx + 1) * args.test_batch_size))
                start_time = time.time()
    # Average per-batch values over the batches actually processed. Dividing by
    # len(dataset) / batch_size instead is fractional when the last batch is partial,
    # which overstates the mean.
    num_batches = batch_idx + 1
    if num_batches == 0:
        raise ValueError('test_loader yielded no batches')
    num_iters = math.ceil(len(test_loader.dataset) / args.test_batch_size)
    test_loss = loss_sum / num_batches
    test_loss_alb = loss_sum_alb / num_batches
    test_loss_threeD = loss_sum_cmap / num_batches
    test_loss_nor = loss_sum_nor / num_batches
    test_loss_dep = loss_sum_dep / num_batches
    test_loss_mask = loss_sum_mask / num_batches
    test_loss_smooth = loss_sum_smooth / num_batches
    test_dice = dice_sum_metric / num_batches
    print(
        'Epoch:{} \n batch index:{}/{}||loss:{:.6f}||alb:{:.6f}||threeD:{:.6f}||nor:{:.6f}||dep:{:.6f}||mask:{:.6f}||'
        'smooth:{:.6f}||dice:{:.6f}'.format(
            epoch,
            num_batches,
            num_iters,
            test_loss.item(),
            test_loss_alb.item(),
            test_loss_threeD.item(),
            test_loss_nor.item(),
            test_loss_dep.item(),
            test_loss_mask.item(),
            test_loss_smooth.item(),
            test_dice.item(),
            ))
    if args.calculate_CC:
        metrics_alb = {key: value / num_batches for key, value in metrics_alb.items()}
        metrics_cmap = {key: value / num_batches for key, value in metrics_cmap.items()}
        metrics_dep = {key: value / num_batches for key, value in metrics_dep.items()}
        metrics_nor = {key: value / num_batches for key, value in metrics_nor.items()}
        metrics_deform_bw = {key: value / num_batches for key, value in metrics_deform_bw.items()}
        metrics_deform_ori = {key: value / num_batches for key, value in metrics_deform_ori.items()}
        metrics_deform_alb = {key: value / num_batches for key, value in metrics_deform_alb.items()}
        for key in metrics_alb:
            print(
                'Metric:{}\t dep:{:.6f}\t nor:{:.6f}\t cmap:{:.6f}\t alb:{:.6f}\t deform_bw:{:.6f}\t'
                ' deform_ori:{:.6f}\t deform_alb:{:.6f}\n'.format(
                    key, metrics_dep[key].item(), metrics_nor[key].item(),
                    metrics_cmap[key].item(), metrics_alb[key].item(),
                    metrics_deform_bw[key].item(), metrics_deform_ori[key].item(), metrics_deform_alb[key].item()))
    if args.write_txt:
        txt_dir = 'output_txt/' + args.model_name + '.txt'
        # Context manager so the log file is closed even if a write fails.
        with open(txt_dir, 'a') as f:
            f.write(
                'Epoch: {} \t Test Loss: {:.6f}, \t ab: {:.4f}, \t cmap: {:.4f}, \t normal: {:.4f}, \t depth: {:.4f}, \t mask: {:.6f} , '
                ' \t smooth:{:.6f},\t dice:{:.6f}\n'.format(epoch,
                                                          test_loss.item(),
                                                          test_loss_alb.item(),
                                                          test_loss_threeD.item(),
                                                          test_loss_nor.item(),
                                                          test_loss_dep.item(),
                                                          test_loss_mask.item(),
                                                          test_loss_smooth.item(),
                                                          test_dice.item()))
            if args.calculate_CC:
                for key in metrics_alb:
                    f.write(
                        'Metric:{}\t dep:{:.6f}\t nor:{:.6f}\t cmap:{:.6f}\t alb:{:.6f}\t deform_bw:{:.6f}\t'
                        ' deform_ori:{:.6f}\t deform_alb:{:.6f}\n'.format(key,
                                                                          metrics_dep[key].item(),
                                                                          metrics_nor[key].item(),
                                                                          metrics_cmap[key].item(),
                                                                          metrics_alb[key].item(),
                                                                          metrics_deform_bw[key].item(),
                                                                          metrics_deform_ori[key].item(),
                                                                          metrics_deform_alb[key].item()))
    if args.write_summary:
        print('sstep', sstep)
        writer.add_scalar('summary/test_loss', test_loss.item(), global_step=sstep)
        writer.add_scalar('summary/test_loss_ab', test_loss_alb.item(), global_step=sstep)
        writer.add_scalar('summary/test_loss_cmap', test_loss_threeD.item(), global_step=sstep)
        writer.add_scalar('summary/test_loss_normal', test_loss_nor.item(), global_step=sstep)
        writer.add_scalar('summary/test_loss_depth', test_loss_dep.item(), global_step=sstep)
        writer.add_scalar('summary/test_loss_back', test_loss_mask.item(), global_step=sstep)
        writer.add_scalar('summary/test_loss_smooth', test_loss_smooth.item(), global_step=sstep)