def test(args, model, device, test_loader, criterion, epoch, writer, output_dir, isWriteImage, sstep):
    """Evaluate the multi-head model for one epoch on the test set.

    Accumulates the combined loss over all heads (uv / albedo / normal /
    depth / mask / 3D coordinate map), the four cross-branch consistency
    terms and a total-variation term on the mask.  Optionally:

    * ``args.calculate_CC``  -- accumulate per-map image metrics
      (l1/mse/pearson/cc/psnr/ssim/mssim) via ``metrics.film_metrics``;
    * ``isWriteImage``       -- dump prediction/GT images for the last batch;
    * ``args.write_txt``     -- append the epoch summary to a text file;
    * ``args.write_summary`` -- log scalars to TensorBoard at step ``sstep``.

    NOTE(review): this module defines ``test`` more than once; only the last
    definition is effective at import time.
    """
    print('Testing')
    model.eval()

    def _zero_metrics():
        # Fresh accumulator for one map type; summed per batch, averaged later.
        return {'l1_norm': 0, 'mse': 0, 'pearsonr_metric': 0, 'cc': 0,
                'psnr': 0, 'ssim': 0, 'mssim': 0}

    metrics_uv = _zero_metrics()
    metrics_cmap = _zero_metrics()
    metrics_alb = _zero_metrics()
    metrics_dep = _zero_metrics()
    metrics_nor = _zero_metrics()
    metrics_bw = _zero_metrics()
    metrics_deori = _zero_metrics()
    metrics_dealb = _zero_metrics()

    with torch.no_grad():
        loss_sum = 0
        loss_sum_alb = 0
        loss_sum_threeD = 0
        loss_sum_uv = 0
        loss_sum_nor = 0
        loss_sum_dep = 0
        loss_sum_mask = 0
        cons_sum_t2n = 0
        cons_sum_t2d = 0
        cons_sum_n2d = 0
        cons_sum_d2n = 0
        loss_sum_tv = 0
        tv_critic = TVLoss().to(device)
        # Hoisted out of the loop: the loss module is stateless, no need to
        # re-instantiate it per batch.
        bc_critic = nn.BCELoss()

        def _consistency(pred, gt):
            # Cross-branch consistency loss; zero tensor when the model did
            # not produce the converted map.
            if pred is not None:
                return criterion(pred, gt).float()
            return torch.Tensor([0.]).to(device)

        start_time = time.time()
        for batch_idx, data in enumerate(test_loader):
            rgb = data[0].to(device)
            alb_map_gt = data[1].to(device)
            dep_map_gt = data[2].to(device)
            nor_map_gt = data[3].to(device)
            threeD_map_gt = data[4].to(device)
            uv_map_gt = data[5].to(device)
            mask_map_gt = data[6].to(device)

            (uv_map, threeD_map, nor_map, alb_map, dep_map, mask_map,
             nor_from_threeD, dep_from_threeD, nor_from_dep, dep_from_nor) = model(rgb)

            loss_mask = bc_critic(mask_map, mask_map_gt).float()
            loss_tv = tv_critic(mask_map).float()
            loss_threeD = criterion(threeD_map, threeD_map_gt).float()
            loss_uv = criterion(uv_map, uv_map_gt).float()
            loss_dep = criterion(dep_map, dep_map_gt).float()
            loss_nor = criterion(nor_map, nor_map_gt).float()
            # Albedo prediction is single-channel: compare against channel 0
            # of the (3-channel) ground truth only.
            loss_alb = criterion(alb_map, torch.unsqueeze(alb_map_gt[:, 0, :, :], 1)).float()

            cons_threeD2nor = _consistency(nor_from_threeD, nor_map_gt)
            cons_threeD2dep = _consistency(dep_from_threeD, dep_map_gt)
            cons_dep2nor = _consistency(nor_from_dep, nor_map_gt)
            cons_nor2dep = _consistency(dep_from_nor, dep_map_gt)

            test_loss = (4 * loss_uv + 4 * loss_alb + 4 * loss_nor + loss_dep
                         + 2 * loss_mask + loss_threeD + cons_threeD2nor
                         + cons_threeD2dep + cons_dep2nor + cons_nor2dep + loss_tv)

            loss_sum = loss_sum + test_loss
            loss_sum_alb += loss_alb
            loss_sum_threeD += loss_threeD
            loss_sum_uv += loss_uv
            loss_sum_nor += loss_nor
            loss_sum_dep += loss_dep
            loss_sum_mask += loss_mask
            cons_sum_t2n += cons_threeD2nor
            cons_sum_t2d += cons_threeD2dep
            cons_sum_n2d += cons_nor2dep
            cons_sum_d2n += cons_dep2nor
            loss_sum_tv += loss_tv

            if args.calculate_CC:
                # Undo the input normalisation before measuring image metrics.
                # NOTE: alb_map_gt is rebound to its single-channel view here,
                # which the image-dump branch below then also sees (matches
                # the original behaviour).
                alb_map_gt = torch.unsqueeze(alb_map_gt[:, 0, :, :], 1)
                alb_map_gt_recover = re_normalize(alb_map_gt, mean=0.5, std=0.5, inplace=False)
                alb_map_recover = re_normalize(alb_map, mean=0.5, std=0.5, inplace=False)
                metric_op = metrics.film_metrics().to(device)
                metric_alb = metric_op(alb_map_recover, alb_map_gt_recover)

                uv_map_recover = re_normalize(uv_map, mean=[0.5, 0.5], std=[0.5, 0.5], inplace=False)
                uv_map_gt_recover = re_normalize(uv_map_gt, mean=[0.5, 0.5], std=[0.5, 0.5], inplace=False)
                metric_uv = metric_op(uv_map_recover, uv_map_gt_recover)

                threeD_map_recover = re_normalize(threeD_map, mean=[0.1108, 0.3160, 0.2859],
                                                  std=[0.7065, 0.6840, 0.7141], inplace=False)
                threeD_map_gt_recover = re_normalize(threeD_map_gt, mean=[0.1108, 0.3160, 0.2859],
                                                     std=[0.7065, 0.6840, 0.7141], inplace=False)
                metric_cmap = metric_op(threeD_map_recover, threeD_map_gt_recover)

                nor_map_recover = re_normalize(nor_map, mean=[0.5619, 0.2881, 0.2917],
                                               std=[0.5619, 0.7108, 0.7083], inplace=False)
                nor_map_gt_recover = re_normalize(nor_map_gt, mean=[0.5619, 0.2881, 0.2917],
                                                  std=[0.5619, 0.7108, 0.7083], inplace=False)
                metric_nor = metric_op(nor_map_recover, nor_map_gt_recover)

                dep_map_recover = re_normalize(dep_map, mean=0.5, std=0.5, inplace=False)
                dep_map_gt_recover = re_normalize(dep_map_gt, mean=0.5, std=0.5, inplace=False)
                metric_dep = metric_op(dep_map_recover, dep_map_gt_recover)

                bw_pred = metrics.uv2bmap4d(uv_map, mask_map)
                bw_gt = metrics.uv2bmap4d(uv_map_gt, mask_map_gt)
                bw_pred = torch.from_numpy(bw_pred).to(device)
                bw_gt = torch.from_numpy(bw_gt).to(device)
                metric_bw = metric_op(bw_pred, bw_gt)

                dewarp_ori = metrics.bw_mapping4d(bw_pred, rgb, device)
                dewarp_ori_gt = metrics.bw_mapping4d(bw_gt, rgb, device)
                metric_deori = metric_op(dewarp_ori, dewarp_ori_gt)

                dewarp_ab = metrics.bw_mapping4d(bw_pred, alb_map, device)
                dewarp_ab_gt = metrics.bw_mapping4d(bw_gt, alb_map_gt, device)
                metric_dealb = metric_op(torch.unsqueeze(dewarp_ab, 0),
                                         torch.unsqueeze(dewarp_ab_gt, 0))

                metrics_alb = {key: metrics_alb[key] + metric_alb[key] for key in metrics_alb}
                metrics_uv = {key: metrics_uv[key] + metric_uv[key] for key in metrics_uv}
                metrics_cmap = {key: metrics_cmap[key] + metric_cmap[key] for key in metrics_cmap}
                metrics_dep = {key: metrics_dep[key] + metric_dep[key] for key in metrics_dep}
                metrics_nor = {key: metrics_nor[key] + metric_nor[key] for key in metrics_nor}
                metrics_bw = {key: metrics_bw[key] + metric_bw[key] for key in metrics_bw}
                metrics_deori = {key: metrics_deori[key] + metric_deori[key] for key in metrics_deori}
                metrics_dealb = {key: metrics_dealb[key] + metric_dealb[key] for key in metrics_dealb}

            if isWriteImage:
                # Only dump images for the last full batch of the epoch.
                if batch_idx == (len(test_loader.dataset) // args.test_batch_size) - 1:
                    dump_dir = output_dir + 'test/epoch_{}_batch_{}'.format(epoch, batch_idx)
                    os.makedirs(dump_dir, exist_ok=True)
                    print('writing test image')
                    # Write at most 5 samples, but never more than the batch
                    # actually holds (the hard-coded 5 could IndexError on a
                    # smaller batch).
                    for k in range(min(5, rgb.size(0))):
                        albedo_pred = alb_map[k, :, :, :]
                        uv_pred = uv_map[k, :, :, :]
                        back_pred = torch.round(mask_map[k, :, :, :])
                        cmap_pred = threeD_map[k, :, :, :]
                        depth_pred = dep_map[k, :, :, :]
                        normal_pred = nor_map[k, :, :, :]
                        ori_gt = rgb[k, :, :, :]
                        ab_gt = alb_map_gt[k, :, :, :]
                        uv_gt = uv_map_gt[k, :, :, :]
                        mask_gt = mask_map_gt[k, :, :, :]
                        cmap_gt = threeD_map_gt[k, :, :, :]
                        depth_gt = dep_map_gt[k, :, :, :]
                        normal_gt = nor_map_gt[k, :, :, :]
                        bw_gt = metrics.uv2bmap(uv_gt, mask_gt)
                        bw_pred = metrics.uv2bmap(uv_pred, back_pred)  # [-1,1], [256, 256, 3]
                        dewarp_ori = metrics.bw_mapping(bw_pred, ori_gt, device)
                        dewarp_ab = metrics.bw_mapping(bw_pred, ab_gt, device)
                        dewarp_ori_gt = metrics.bw_mapping(bw_gt, ori_gt, device)
                        prefix = output_dir + 'test/epoch_{}_batch_{}/'.format(epoch, batch_idx)
                        # pred
                        write_image_tensor(uv_pred, prefix + 'pred_uv_ind_{}.jpg'.format(k), 'std', device=device)
                        write_image_tensor(back_pred, prefix + 'pred_back_ind_{}.jpg'.format(k), '01')
                        write_image_tensor(albedo_pred, prefix + 'pred_ab_ind_{}.jpg'.format(k), 'std')
                        write_image_tensor(cmap_pred, prefix + 'pred_3D_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141])
                        write_image_tensor(depth_pred, prefix + 'pred_depth_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.5], std=[0.5])
                        write_image_tensor(normal_pred, prefix + 'pred_normal_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083])
                        write_image_np(bw_pred, prefix + 'pred_bw_ind_{}.jpg'.format(k))
                        # gt
                        write_image_tensor(ori_gt, prefix + 'gt_ori_ind_{}.jpg'.format(k), 'std')
                        write_image_tensor(uv_gt, prefix + 'gt_uv_ind_{}.jpg'.format(k), 'std', device=device)
                        write_image_tensor(mask_gt, prefix + 'gt_back_ind_{}.jpg'.format(k), '01')
                        write_image_tensor(ab_gt, prefix + 'gt_ab_ind_{}.jpg'.format(k), 'std')
                        write_image_tensor(cmap_gt, prefix + 'gt_cmap_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141])
                        write_image_tensor(depth_gt, prefix + 'gt_depth_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.5], std=[0.5])
                        write_image_tensor(normal_gt, prefix + 'gt_normal_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083])
                        write_image_np(bw_gt, prefix + 'gt_bw_ind_{}.jpg'.format(k))
                        write_image(dewarp_ori_gt, prefix + 'gt_dewarpOri_ind_{}.jpg'.format(k))
                        # dewarp
                        write_image(dewarp_ori, prefix + 'dewarp_ori_ind_{}.jpg'.format(k))
                        write_image(dewarp_ab, prefix + 'dewarp_ab_ind_{}.jpg'.format(k))

            if (batch_idx + 1) % 20 == 0:
                print('It cost {} seconds to test {} images'.format(
                    time.time() - start_time, (batch_idx + 1) * args.test_batch_size))
                start_time = time.time()

        # Average accumulated losses over the (possibly fractional) number of
        # batches, matching the original arithmetic.
        n_batches = len(test_loader.dataset) / args.test_batch_size
        test_loss = loss_sum / n_batches
        test_loss_alb = loss_sum_alb / n_batches
        test_loss_threeD = loss_sum_threeD / n_batches
        test_loss_uv = loss_sum_uv / n_batches
        test_loss_nor = loss_sum_nor / n_batches
        test_loss_dep = loss_sum_dep / n_batches
        test_loss_mask = loss_sum_mask / n_batches
        test_cons_t2n = cons_sum_t2n / n_batches
        test_cons_t2d = cons_sum_t2d / n_batches
        test_cons_n2d = cons_sum_n2d / n_batches
        test_cons_d2n = cons_sum_d2n / n_batches
        # BUG FIX: previously divided the *last batch's* loss_tv instead of
        # the accumulated loss_sum_tv.
        test_loss_tv = loss_sum_tv / n_batches

        print(
            'Epoch:{} \n batch index:{}/{}||loss:{:.6f}||alb:{:.4f}||threeD:{:.4f}||uv:{:.6f}||nor:{:.4f}||dep:{:.4f}||mask:{:.6f}||cons_t2n:{:6f}'
            'cons_t2d:{:.6f}||cons_n2d:{:.6f}||cons_d2n:{:.6f}||loss_tv:{:.6f}'.format(
                epoch, batch_idx + 1, len(test_loader.dataset) // args.batch_size,
                test_loss.item(), test_loss_alb.item(), test_loss_threeD.item(),
                test_loss_uv.item(), test_loss_nor.item(), test_loss_dep.item(),
                test_loss_mask.item(), test_cons_t2n.item(), test_cons_t2d.item(),
                test_cons_n2d.item(), test_cons_d2n.item(), test_loss_tv.item()))

        if args.calculate_CC:
            num_iters = math.ceil(len(test_loader.dataset) / args.test_batch_size)
            metrics_alb = {key: metrics_alb[key] / num_iters for key in metrics_alb}
            metrics_uv = {key: metrics_uv[key] / num_iters for key in metrics_uv}
            metrics_cmap = {key: metrics_cmap[key] / num_iters for key in metrics_cmap}
            metrics_dep = {key: metrics_dep[key] / num_iters for key in metrics_dep}
            metrics_nor = {key: metrics_nor[key] / num_iters for key in metrics_nor}
            metrics_bw = {key: metrics_bw[key] / num_iters for key in metrics_bw}
            metrics_deori = {key: metrics_deori[key] / num_iters for key in metrics_deori}
            metrics_dealb = {key: metrics_dealb[key] / num_iters for key in metrics_dealb}

        if args.calculate_CC:
            for key in metrics_alb:
                # BUG FIX: .format() used to bind only to the last '+'-joined
                # literal, so every column except 'dealb' printed its template
                # verbatim (and 'dealb' got the uv value); format the whole
                # row at once instead.
                print('{0}_uv:{1:.6f}\t{0}_dep:{2:.6f}\t{0}_nor:{3:.6f}\t{0}_cmap:{4:.6f}\t'
                      '{0}_alb:{5:.6f}\t{0}_bw:{6:.6f}\t{0}_deori:{7:.6f}\t{0}_dealb:{8:.6f}'.format(
                          key, metrics_uv[key], metrics_dep[key], metrics_nor[key],
                          metrics_cmap[key], metrics_alb[key], metrics_bw[key],
                          metrics_deori[key], metrics_dealb[key]))

        if args.write_txt:
            txt_dir = 'output_txt/' + args.model_name + '.txt'
            # Context manager guarantees the file is closed even on error.
            with open(txt_dir, 'a') as f:
                f.write(
                    'Epoch: {} \t Test Loss: {:.6f}, \t ab: {:.4f}, \t cmap: {:.4f}, \t uv: {:.6f}, \t normal: {:.4f}, \t depth: {:.4f}, \t back: {:.6f} , \t constrain 3d to normal: {:.4f}, \t constrain 3d to depth: {:.4f}, \t constrain normal to depth: {:.4f}, \t constrain depth to normal: {:.4f}, \t loss tv: {:.6f}\n'.format(
                        epoch, test_loss.item(), test_loss_alb.item(), test_loss_threeD.item(),
                        test_loss_uv.item(), test_loss_nor.item(), test_loss_dep.item(),
                        test_loss_mask.item(), test_cons_t2n.item(), test_cons_t2d.item(),
                        test_cons_n2d.item(), test_cons_d2n.item(), test_loss_tv.item()))
                for key in metrics_alb:
                    # Same .format-binding fix as the console output above.
                    f.write('{0}_uv:{1:.6f}\t{0}_dep:{2:.6f}\t{0}_nor:{3:.6f}\t{0}_cmap:{4:.6f}\t'
                            '{0}_alb:{5:.6f}\t{0}_bw:{6:.6f}\t{0}_deori:{7:.6f}\t{0}_dealb:{8:.6f}\n'.format(
                                key, metrics_uv[key], metrics_dep[key], metrics_nor[key],
                                metrics_cmap[key], metrics_alb[key], metrics_bw[key],
                                metrics_deori[key], metrics_dealb[key]))

        if args.write_summary:
            print('sstep', sstep)
            # writer.add_scalar('test_acc', 100. * correct / len(test_loader.dataset), global_step=epoch+1)
            writer.add_scalar('summary/test_loss', test_loss.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_ab', test_loss_alb.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_cmap', test_loss_threeD.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_uv', test_loss_uv.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_normal', test_loss_nor.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_depth', test_loss_dep.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_back', test_loss_mask.item(), global_step=sstep)
            writer.add_scalar('summary/test_con_3d2nor', test_cons_t2n.item(), global_step=sstep)
            writer.add_scalar('summary/test_con_3d2dep', test_cons_t2d.item(), global_step=sstep)
            writer.add_scalar('summary/test_con_nor2dep', test_cons_n2d.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_dep2nor', test_cons_d2n.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_tv', test_loss_tv.item(), global_step=sstep)
def test(args, model, device, test_loader, criterion, epoch, writer, output_dir, isWriteImage, sstep):
    """Evaluate the multi-head model for one epoch (CC-metric variant).

    Accumulates the combined loss over all heads (uv / albedo / normal /
    depth / mask / 3D coordinate map) plus the four cross-branch consistency
    terms.  Optionally:

    * ``args.calculate_CC``  -- accumulate scalar CC metrics via
      ``metrics.calculate_CC_metrics``;
    * ``isWriteImage``       -- dump prediction/GT images for the last batch;
    * ``args.write_txt``     -- append the epoch summary to a text file;
    * ``args.write_summary`` -- log scalars to TensorBoard at step ``sstep``.

    NOTE(review): this module defines ``test`` more than once; only the last
    definition is effective at import time.
    """
    print('Testing')
    model.eval()
    cc_uv = 0
    cc_threeD = 0
    cc_alb = 0
    cc_bw = 0
    cc_dewarp_ori = 0
    cc_dewarp_alb = 0
    # BUG FIX: these averaged names were only assigned inside the
    # args.calculate_CC branch but are referenced unconditionally by the
    # args.write_txt branch; pre-seed them so write_txt without calculate_CC
    # cannot raise NameError.
    cc_cmap = 0
    cc_ab = 0
    cc_dewarp_ab = 0

    with torch.no_grad():
        loss_sum = 0
        loss_sum_alb = 0
        loss_sum_threeD = 0
        loss_sum_uv = 0
        loss_sum_nor = 0
        loss_sum_dep = 0
        loss_sum_mask = 0
        cons_sum_t2n = 0
        cons_sum_t2d = 0
        cons_sum_n2d = 0
        cons_sum_d2n = 0

        def _consistency(pred, gt):
            # Cross-branch consistency loss; zero tensor when the model did
            # not produce the converted map.
            if pred is not None:
                return criterion(pred, gt).float()
            return torch.Tensor([0.]).to(device)

        start_time = time.time()
        for batch_idx, data in enumerate(test_loader):
            rgb = data[0].to(device)
            alb_map_gt = data[1].to(device)
            dep_map_gt = data[2].to(device)
            nor_map_gt = data[3].to(device)
            threeD_map_gt = data[4].to(device)
            uv_map_gt = data[5].to(device)
            mask_map_gt = data[6].to(device)

            (uv_map, threeD_map, nor_map, alb_map, dep_map, mask_map,
             nor_from_threeD, dep_from_threeD, nor_from_dep, dep_from_nor) = model(rgb)

            loss_mask = criterion(mask_map, mask_map_gt).float()
            loss_threeD = criterion(threeD_map, threeD_map_gt).float()
            loss_uv = criterion(uv_map, uv_map_gt).float()
            loss_dep = criterion(dep_map, dep_map_gt).float()
            loss_nor = criterion(nor_map, nor_map_gt).float()
            # Albedo prediction is single-channel: compare against channel 0
            # of the (3-channel) ground truth only.
            loss_alb = criterion(alb_map, torch.unsqueeze(alb_map_gt[:, 0, :, :], 1)).float()

            cons_threeD2nor = _consistency(nor_from_threeD, nor_map_gt)
            cons_threeD2dep = _consistency(dep_from_threeD, dep_map_gt)
            cons_dep2nor = _consistency(nor_from_dep, nor_map_gt)
            cons_nor2dep = _consistency(dep_from_nor, dep_map_gt)

            test_loss = (4 * loss_uv + 4 * loss_alb + loss_nor + loss_dep
                         + 2 * loss_mask + loss_threeD + cons_threeD2nor
                         + cons_threeD2dep + cons_dep2nor + cons_nor2dep)

            loss_sum = loss_sum + test_loss
            loss_sum_alb += loss_alb
            loss_sum_threeD += loss_threeD
            loss_sum_uv += loss_uv
            loss_sum_nor += loss_nor
            loss_sum_dep += loss_dep
            loss_sum_mask += loss_mask
            cons_sum_t2n += cons_threeD2nor
            cons_sum_t2d += cons_threeD2dep
            cons_sum_n2d += cons_nor2dep
            cons_sum_d2n += cons_dep2nor

            if args.calculate_CC:
                c_alb = metrics.calculate_CC_metrics(
                    alb_map, torch.unsqueeze(alb_map_gt[:, 0, :, :], 1))
                c_uv = metrics.calculate_CC_metrics(uv_map, uv_map_gt)
                c_threeD = metrics.calculate_CC_metrics(threeD_map, threeD_map_gt)
                bw_pred = metrics.uv2bmap4d(uv_map, mask_map)
                bw_gt = metrics.uv2bmap4d(uv_map_gt, mask_map_gt)
                c_bw = metrics.calculate_CC_metrics(bw_pred, bw_gt)
                dewarp_ori = metrics.bw_mapping4d(bw_pred, rgb, device)
                dewarp_ori_gt = metrics.bw_mapping4d(bw_gt, rgb, device)
                c_dewarp_ori = metrics.calculate_CC_metrics(dewarp_ori, dewarp_ori_gt)
                dewarp_ab = metrics.bw_mapping4d(bw_pred, alb_map, device)
                dewarp_ab_gt = metrics.bw_mapping4d(
                    bw_gt, torch.unsqueeze(alb_map_gt[:, 0, :, :], 1), device)
                c_dewarp_alb = metrics.calculate_CC_metrics(
                    torch.unsqueeze(dewarp_ab, 0), torch.unsqueeze(dewarp_ab_gt, 0))
                cc_alb += c_alb
                cc_uv += c_uv
                cc_threeD += c_threeD
                cc_bw += c_bw
                cc_dewarp_ori += c_dewarp_ori
                cc_dewarp_alb += c_dewarp_alb

            if isWriteImage:
                # Only dump images for the last full batch of the epoch.
                if batch_idx == (len(test_loader.dataset) // args.test_batch_size) - 1:
                    dump_dir = output_dir + 'test/epoch_{}_batch_{}'.format(epoch, batch_idx)
                    os.makedirs(dump_dir, exist_ok=True)
                    print('writing test image')
                    # Never index past the actual batch (the last batch may be
                    # smaller than args.test_batch_size).
                    for k in range(min(args.test_batch_size, rgb.size(0))):
                        albedo_pred = alb_map[k, :, :, :]
                        uv_pred = uv_map[k, :, :, :]
                        back_pred = mask_map[k, :, :, :]
                        cmap_pred = threeD_map[k, :, :, :]
                        depth_pred = dep_map[k, :, :, :]
                        normal_pred = nor_map[k, :, :, :]
                        ori_gt = rgb[k, :, :, :]
                        ab_gt = alb_map_gt[k, :, :, :]
                        uv_gt = uv_map_gt[k, :, :, :]
                        mask_gt = mask_map_gt[k, :, :, :]
                        cmap_gt = threeD_map_gt[k, :, :, :]
                        depth_gt = dep_map_gt[k, :, :, :]
                        normal_gt = nor_map_gt[k, :, :, :]
                        bw_gt = metrics.uv2bmap(uv_gt, mask_gt)
                        bw_pred = metrics.uv2bmap(uv_pred, back_pred)  # [-1,1], [256, 256, 3]
                        dewarp_ori = metrics.bw_mapping(bw_pred, ori_gt, device)
                        dewarp_ab = metrics.bw_mapping(bw_pred, ab_gt, device)
                        dewarp_ori_gt = metrics.bw_mapping(bw_gt, ori_gt, device)
                        prefix = output_dir + 'test/epoch_{}_batch_{}/'.format(epoch, batch_idx)
                        # pred
                        write_image_tensor(uv_pred, prefix + 'pred_uv_ind_{}.jpg'.format(k), 'std', device=device)
                        write_image_tensor(back_pred, prefix + 'pred_back_ind_{}.jpg'.format(k), '01')
                        write_image_tensor(albedo_pred, prefix + 'pred_ab_ind_{}.jpg'.format(k), 'std')
                        write_image_tensor(cmap_pred, prefix + 'pred_3D_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.100, 0.326, 0.289], std=[0.096, 0.332, 0.298])
                        write_image_tensor(depth_pred, prefix + 'pred_depth_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.316], std=[0.309])
                        write_image_tensor(normal_pred, prefix + 'pred_normal_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.584, 0.294, 0.300], std=[0.483, 0.251, 0.256])
                        write_image_np(bw_pred, prefix + 'pred_bw_ind_{}.jpg'.format(k))
                        # gt
                        write_image_tensor(ori_gt, prefix + 'gt_ori_ind_{}.jpg'.format(k), 'std')
                        write_image_tensor(uv_gt, prefix + 'gt_uv_ind_{}.jpg'.format(k), 'std', device=device)
                        write_image_tensor(mask_gt, prefix + 'gt_back_ind_{}.jpg'.format(k), '01')
                        write_image_tensor(ab_gt, prefix + 'gt_ab_ind_{}.jpg'.format(k), 'std')
                        write_image_tensor(cmap_gt, prefix + 'gt_cmap_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.100, 0.326, 0.289], std=[0.096, 0.332, 0.298])
                        write_image_tensor(depth_gt, prefix + 'gt_depth_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.316], std=[0.309])
                        write_image_tensor(normal_gt, prefix + 'gt_normal_ind_{}.jpg'.format(k), 'gauss',
                                           mean=[0.584, 0.294, 0.300], std=[0.483, 0.251, 0.256])
                        write_image_np(bw_gt, prefix + 'gt_bw_ind_{}.jpg'.format(k))
                        write_image(dewarp_ori_gt, prefix + 'gt_dewarpOri_ind_{}.jpg'.format(k))
                        # dewarp
                        write_image(dewarp_ori, prefix + 'dewarp_ori_ind_{}.jpg'.format(k))
                        write_image(dewarp_ab, prefix + 'dewarp_ab_ind_{}.jpg'.format(k))

            if (batch_idx + 1) % 20 == 0:
                print('It cost {} seconds to test {} images'.format(
                    time.time() - start_time, (batch_idx + 1) * args.test_batch_size))
                start_time = time.time()

        # Average accumulated losses over the (possibly fractional) number of
        # batches, matching the original arithmetic.
        n_batches = len(test_loader.dataset) / args.test_batch_size
        test_loss = loss_sum / n_batches
        test_loss_alb = loss_sum_alb / n_batches
        test_loss_threeD = loss_sum_threeD / n_batches
        test_loss_uv = loss_sum_uv / n_batches
        test_loss_nor = loss_sum_nor / n_batches
        test_loss_dep = loss_sum_dep / n_batches
        test_loss_mask = loss_sum_mask / n_batches
        test_cons_t2n = cons_sum_t2n / n_batches
        test_cons_t2d = cons_sum_t2d / n_batches
        test_cons_n2d = cons_sum_n2d / n_batches
        test_cons_d2n = cons_sum_d2n / n_batches

        print(
            'Epoch:{} \n batch index:{}/{}||loss:{:.6f}||alb:{:.4f}||threeD:{:.4f}||uv:{:.6f}||nor:{:.4f}||dep:{:.4f}||mask:{:.6f}||cons_t2n:{:6f}'
            'cons_t2d:{:.6f}||cons_n2d:{:.6f}||cons_d2n:{:.6f}'.format(
                epoch, batch_idx + 1, len(test_loader.dataset) // args.batch_size,
                test_loss.item(), test_loss_alb.item(), test_loss_threeD.item(),
                test_loss_uv.item(), test_loss_nor.item(), test_loss_dep.item(),
                test_loss_mask.item(), test_cons_t2n.item(), test_cons_t2d.item(),
                test_cons_n2d.item(), test_cons_d2n.item()))

        if args.calculate_CC:
            cc_uv = cc_uv / n_batches
            cc_cmap = cc_threeD / n_batches
            cc_ab = cc_alb / n_batches
            cc_bw = cc_bw / n_batches
            cc_dewarp_ori = cc_dewarp_ori / n_batches
            cc_dewarp_ab = cc_dewarp_alb / n_batches

        if args.calculate_CC:
            print(
                'CC_uv: {}\t CC_cmap: {}\t CC_ab: {}\t CC_bw: {}\t CC_dewarp_ori: {}\t CC_dewarp_ab: {}'
                .format(cc_uv, cc_cmap, cc_ab, cc_bw, cc_dewarp_ori, cc_dewarp_ab))

        if args.write_txt:
            txt_dir = 'output_txt/' + args.model_name + '.txt'
            # Context manager guarantees the file is closed even on error.
            with open(txt_dir, 'a') as f:
                f.write(
                    'Epoch: {} \t Test Loss: {:.6f}, \t ab: {:.4f}, \t cmap: {:.4f}, \t uv: {:.6f}, \t normal: {:.4f}, \t depth: {:.4f}, \t back: {:.6f} , \t constrain 3d to normal: {:.4f}, \t constrain 3d to depth: {:.4f}, \t constrain normal to depth: {:.4f}, \t constrain depth to normal: {:.4f}, CC_uv: {}\t CC_cmap: {}\t CC_ab: {}\t CC_bw: {}\t CC_dewarp_ori: {}\t CC_dewarp_ab: {}\n'
                    .format(epoch, test_loss.item(), test_loss_alb.item(), test_loss_threeD.item(),
                            test_loss_uv.item(), test_loss_nor.item(), test_loss_dep.item(),
                            test_loss_mask.item(), test_cons_t2n.item(), test_cons_t2d.item(),
                            test_cons_n2d.item(), test_cons_d2n.item(), cc_uv, cc_cmap, cc_ab,
                            cc_bw, cc_dewarp_ori, cc_dewarp_ab))

        if args.write_summary:
            print('sstep', sstep)
            # writer.add_scalar('test_acc', 100. * correct / len(test_loader.dataset), global_step=epoch+1)
            writer.add_scalar('summary/test_loss', test_loss.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_ab', test_loss_alb.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_cmap', test_loss_threeD.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_uv', test_loss_uv.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_normal', test_loss_nor.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_depth', test_loss_dep.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_back', test_loss_mask.item(), global_step=sstep)
            writer.add_scalar('summary/test_con_3d2nor', test_cons_t2n.item(), global_step=sstep)
            writer.add_scalar('summary/test_con_3d2dep', test_cons_t2d.item(), global_step=sstep)
            writer.add_scalar('summary/test_con_nor2dep', test_cons_n2d.item(), global_step=sstep)
            writer.add_scalar('summary/test_loss_dep2nor', test_cons_d2n.item(), global_step=sstep)
def test(args, model, device, test_loader, criterion, epoch, writer, output_dir_test, isWriteImage, sstep): print('Testing') model.eval() metrics_cmap = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0} metrics_alb = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0} metrics_dep = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0} metrics_nor = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0} metrics_deform_bw = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0} metrics_deform_ori = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0} metrics_deform_alb = {'l1_norm': 0, 'mse': 0, 'cc': 0, 'psnr': 0, 'ssim': 0, 'mssim': 0, 'ncc': 0} with torch.no_grad(): loss_sum = 0 loss_sum_alb = 0 loss_sum_cmap = 0 loss_sum_nor = 0 loss_sum_dep = 0 loss_sum_mask = 0 loss_sum_smooth = 0 dice_sum_metric = 0 bc_critic = nn.BCELoss() start_time = time.time() for batch_idx, data in enumerate(test_loader): rgb = data[0] alb_map_gt = data[1] dep_map_gt = data[2] nor_map_gt = data[3] cmap_map_gt = data[4] uv_map_gt = data[5] mask_map_gt = data[6] bw_map_gt = data[7] deform_map_gt = data[8] names = data[9] rgb, alb_map_gt, dep_map_gt, nor_map_gt, uv_map_gt, cmap_map_gt, mask_map_gt, bw_map_gt, deform_map_gt = \ rgb.to(device), alb_map_gt.to(device), dep_map_gt.to(device), nor_map_gt.to(device), uv_map_gt.to( device), \ cmap_map_gt.to(device), mask_map_gt.to(device), bw_map_gt.to(device), deform_map_gt.to(device) cmap, nor_map, alb_map, dep_map, mask_map, deform_map = model(rgb) #bw_gt = metrics.uv2bmap4d(uv_map_gt, mask_map_gt) #dewarp_ori_gt = metrics.bw_mapping4d(bw_gt, rgb, device) dewarp_ori_gt = metrics.bw_mapping4d(bw_map_gt, rgb, device) ###################################################################################################################### # TEST LOSS SUMMARY 
###################################################################################################################### loss_mask = bc_critic(mask_map, mask_map_gt).float() dice_metric = diceCoeff(mask_map, mask_map_gt).float() loss_cmap = criterion(cmap, cmap_map_gt).float() loss_dep = criterion(dep_map, dep_map_gt).float() loss_nor = criterion(nor_map, nor_map_gt).float() loss_alb = criterion(alb_map, torch.unsqueeze(alb_map_gt[:, 0, :, :], 1)).float() loss_smooth = criterion(deform_map, deform_map_gt) test_loss = 4 * loss_alb + 4 * loss_nor + loss_dep + 2 * loss_mask + loss_cmap + \ loss_smooth loss_sum = loss_sum + test_loss loss_sum_alb += loss_alb loss_sum_cmap += loss_cmap loss_sum_nor += loss_nor loss_sum_dep += loss_dep loss_sum_mask += loss_mask loss_sum_smooth += loss_smooth dice_sum_metric += dice_metric alb_map_gt = torch.unsqueeze(alb_map_gt[:, 0, :, :], 1) if args.calculate_CC: metric_op = metrics.film_metrics_ncc().to(device) # ALB alb_map_gt_recover = torch.clamp(re_normalize(alb_map_gt, mean=0.5, std=0.5, inplace=False), 0., 1.) alb_map_recover = torch.clamp(re_normalize(alb_map, mean=0.5, std=0.5, inplace=False), 0., 1.) metric_alb = metric_op(alb_map_recover, alb_map_gt_recover) # UV uv_map_gt_recover = torch.clamp(re_normalize(uv_map_gt, mean=[0.5, 0.5], std=[0.5, 0.5], inplace=False), 0., 1.) # CMAP cmap_recover = torch.clamp( re_normalize(cmap, mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141], inplace=False), 0., 1.) cmap_gt_recover = torch.clamp( re_normalize(cmap_map_gt, mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141], inplace=False), 0., 1.) metric_cmap = metric_op(cmap_recover, cmap_gt_recover) # Normal nor_map_recover = torch.clamp( re_normalize(nor_map, mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083], inplace=False), 0., 1.) nor_map_gt_recover = torch.clamp( re_normalize(nor_map_gt, mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083], inplace=False), 0., 1.) 
metric_nor = metric_op(nor_map_recover, nor_map_gt_recover) # Depth dep_map_recover = torch.clamp(re_normalize(dep_map, mean=0.5, std=0.5, inplace=False), 0., 1.) dep_map_gt_recover = torch.clamp(re_normalize(dep_map_gt, mean=0.5, std=0.5, inplace=False), 0., 1.) metric_dep = metric_op(dep_map_recover, dep_map_gt_recover) # UV2BW bw_gt = metrics.uv2bmap4d(uv_map_gt, mask_map_gt) bw_gt_recover = torch.clamp( re_normalize(torch.from_numpy(np.transpose(bw_gt, (0, 3, 1, 2))).to(device), mean=[0.5, 0.5], std=[0.5, 0.5]), 0., 1.) # DEFORM deform_bw_pred = deform2bw_tensor_batch(deform_map.clone(), device) deform_bw_pred_recover = torch.clamp( re_normalize(deform_bw_pred, mean=[0.5, 0.5], std=[0.5, 0.5], inplace=False), 0., 1.) metric_deform_bw = metric_op(deform_bw_pred_recover.float().contiguous(), bw_gt_recover.float().contiguous()) # UV2BW Dewarp ORI dewarp_ori_gt = metrics.bw_mapping4d(bw_gt, rgb, device) dewarp_ori_gt_recover = torch.clamp( re_normalize(dewarp_ori_gt.transpose(2, 3).transpose(1, 2), mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False), 0., 1.) # UV2BW Dewarp ALB dewarp_alb_gt = metrics.bw_mapping4d(bw_gt, alb_map_gt, device) dewarp_alb_gt_recover = torch.clamp(re_normalize(dewarp_alb_gt, mean=[0.5], std=[0.5], inplace=False), 0., 1.) # DEFORM DEWARP ORI deform_dewarp_ori = metrics.bw_mapping4d(deform_bw_pred, rgb, device) deform_dewarp_ori_recover = torch.clamp( re_normalize(deform_dewarp_ori.transpose(2, 3).transpose(1, 2), mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=False), 0., 1.) metric_deform_ori = metric_op(deform_dewarp_ori_recover.contiguous(), dewarp_ori_gt_recover.contiguous()) # DEFORM DEWARP ALB deform_dewarp_alb = metrics.bw_mapping4d(deform_bw_pred, alb_map, device) deform_dewarp_alb_recover = torch.clamp( re_normalize(deform_dewarp_alb, mean=[0.5], std=[0.5], inplace=False), 0., 1.) 
metric_deform_alb = metric_op(torch.unsqueeze(deform_dewarp_alb_recover, 1).contiguous(), torch.unsqueeze(dewarp_alb_gt_recover, 1).contiguous()) # SUMMARY metrics_alb = {key: metrics_alb[key] + metric_alb[key] for key in metrics_alb.keys()} metrics_cmap = {key: metrics_cmap[key] + metric_cmap[key] for key in metrics_cmap.keys()} metrics_dep = {key: metrics_dep[key] + metric_dep[key] for key in metrics_dep.keys()} metrics_nor = {key: metrics_nor[key] + metric_nor[key] for key in metrics_nor.keys()} metrics_deform_bw = {key: metrics_deform_bw[key] + metric_deform_bw[key] for key in metrics_deform_bw.keys()} metrics_deform_ori = {key: metrics_deform_ori[key] + metric_deform_ori[key] for key in metrics_deform_ori.keys()} metrics_deform_alb = {key: metrics_deform_alb[key] + metric_deform_alb[key] for key in metrics_deform_alb.keys()} if isWriteImage: if batch_idx == (len(test_loader.dataset) // args.test_batch_size) - 2: if not os.path.exists(output_dir_test + 'test/epoch_{}_batch_{}'.format(epoch, batch_idx)): os.makedirs(output_dir_test + 'test/epoch_{}_batch_{}'.format(epoch, batch_idx)) print('writing test image') for k in range(args.test_batch_size): alb_pred = alb_map[k, :, :, :] mask_pred = mask_map[k, :, :, :] mask_pred = torch.round(mask_pred) cmap_pred = cmap[k, :, :, :] dep_pred = dep_map[k, :, :, :] nor_pred = nor_map[k, :, :, :] deform_bw_map = deform2bw_tensor_batch(deform_map.clone().to(device), device) deform_bw_pred = deform_bw_map[k, :, :, :] name = names[k] ori_gt = rgb[k, :, :, :] alb_gt = alb_map_gt[k, :, :, :] uv_gt = uv_map_gt[k, :, :, :] mask_gt = mask_map_gt[k, :, :, :] cmap_gt = cmap_map_gt[k, :, :, :] dep_gt = dep_map_gt[k, :, :, :] nor_gt = nor_map_gt[k, :, :, :] bw_gt = metrics.uv2bmap(uv_gt, mask_gt) dewarp_ori_gt = metrics.bw_mapping(bw_gt, ori_gt, device) dewarp_alb_gt = metrics.bw_mapping(bw_gt, alb_gt, device) deform_dewarp_ori = metrics.bw_mapping(deform_bw_pred, ori_gt, device) deform_dewarp_alb = metrics.bw_mapping(deform_bw_pred, 
alb_pred, device) output_dir = os.path.join(output_dir_test, 'test/epoch_{}_batch_{}/'.format(epoch, batch_idx)) output_mask_pred = os.path.join(output_dir, 'pred_mask_ind_{}'.format(name) + '.jpg') output_alb_pred = os.path.join(output_dir, 'pred_alb_ind_{}'.format(name) + '.jpg') output_cmap_pred = os.path.join(output_dir, 'pred_3D_ind_{}'.format(name) + '.jpg') output_deform_bw_pred = os.path.join(output_dir, 'pred_deform_bw_ind_{}'.format(name) + '.exr') output_dep_pred = os.path.join(output_dir, 'pred_dep_ind_{}'.format(name) + '.jpg') output_nor_pred = os.path.join(output_dir, 'pred_normal_ind_{}'.format(name) + '.jpg') output_ori = os.path.join(output_dir, 'gt_ori_ind_{}'.format(name) + '.jpg') output_uv_gt = os.path.join(output_dir, 'gt_uv_ind_{}'.format(name) + '.exr') output_alb_gt = os.path.join(output_dir, 'gt_alb_ind_{}'.format(name) + '.jpg') output_cmap_gt = os.path.join(output_dir, 'gt_cmap_ind_{}'.format(name) + '.jpg') output_mask_gt = os.path.join(output_dir, 'gt_mask_ind_{}'.format(name) + '.jpg') output_bw_gt = os.path.join(output_dir, 'gt_bw_ind_{}'.format(name) + '.exr') output_dewarp_ori_gt = os.path.join(output_dir, 'gt_dewarpOri_ind_{}'.format(name) + '.jpg') output_dewarp_alb_gt = os.path.join(output_dir, 'gt_dewarpAlb_ind_{}'.format(name) + '.jpg') output_dep_gt = os.path.join(output_dir, 'gt_dep_ind_{}'.format(name) + '.jpg') output_nor_gt = os.path.join(output_dir, 'gt_nor_ind_{}'.format(name) + '.jpg') output_deform_dewarp_ori = os.path.join(output_dir, 'deform_dewarp_ori_ind_{}'.format(name) + '.jpg') output_deform_dewarp_alb = os.path.join(output_dir, 'deform_dewarp_alb_ind_{}'.format(name) + '.jpg') """pred""" write_image_tensor(mask_pred, output_mask_pred, '01') write_image_tensor(alb_pred, output_alb_pred, 'std') write_image_tensor(cmap_pred, output_cmap_pred, 'gauss', mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141]) write_image_tensor(dep_pred, output_dep_pred, 'gauss', mean=[0.5], std=[0.5]) 
write_image_tensor(nor_pred, output_nor_pred, 'gauss', mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083]) write_image_tensor(deform_bw_pred, output_deform_bw_pred, 'std', device=device) """gt""" write_image_tensor(ori_gt, output_ori, 'std') write_image_tensor(uv_gt, output_uv_gt, 'std', device=device) write_image_tensor(mask_gt, output_mask_gt, '01') write_image_tensor(alb_gt, output_alb_gt, 'std') write_image_tensor(cmap_gt, output_cmap_gt, 'gauss', mean=[0.1108, 0.3160, 0.2859], std=[0.7065, 0.6840, 0.7141]) write_image_tensor(dep_gt, output_dep_gt, 'gauss', mean=[0.5], std=[0.5]) write_image_tensor(nor_gt, output_nor_gt, 'gauss', mean=[0.5619, 0.2881, 0.2917], std=[0.5619, 0.7108, 0.7083]) write_image_np(bw_gt, output_bw_gt) write_image(dewarp_ori_gt, output_dewarp_ori_gt) write_image(dewarp_alb_gt, output_dewarp_alb_gt) """dewarp""" write_image(deform_dewarp_ori, output_deform_dewarp_ori) write_image(deform_dewarp_alb, output_deform_dewarp_alb) if (batch_idx + 1) % 20 == 0: print('It cost {} seconds to test {} images'.format(time.time() - start_time, (batch_idx + 1) * args.test_batch_size)) start_time = time.time() test_loss = loss_sum / (len(test_loader.dataset) / args.test_batch_size) test_loss_alb = loss_sum_alb / (len(test_loader.dataset) / args.test_batch_size) test_loss_threeD = loss_sum_cmap / (len(test_loader.dataset) / args.test_batch_size) test_loss_nor = loss_sum_nor / (len(test_loader.dataset) / args.test_batch_size) test_loss_dep = loss_sum_dep / (len(test_loader.dataset) / args.test_batch_size) test_loss_mask = loss_sum_mask / (len(test_loader.dataset) / args.test_batch_size) test_loss_smooth = loss_sum_smooth / (len(test_loader.dataset) / args.test_batch_size) test_dice = dice_sum_metric / (len(test_loader.dataset) / args.test_batch_size) print( 'Epoch:{} \n batch index:{}/{}||loss:{:.6f}||alb:{:.6f}||threeD:{:.6f}||nor:{:.6f}||dep:{:.6f}||mask:{:.6f}||' 'smooth:{:.6f}||dice:{:.6f}'.format( epoch, batch_idx + 1, len( 
test_loader.dataset) // args.batch_size, test_loss.item(), test_loss_alb.item(), test_loss_threeD.item(), test_loss_nor.item(), test_loss_dep.item(), test_loss_mask.item(), test_loss_smooth.item(), test_dice.item(), )) if args.calculate_CC: num_iters = math.ceil(len(test_loader.dataset) / args.test_batch_size) metrics_alb = {key: metrics_alb[key] / num_iters for key in metrics_alb.keys()} metrics_cmap = {key: metrics_cmap[key] / num_iters for key in metrics_cmap.keys()} metrics_dep = {key: metrics_dep[key] / num_iters for key in metrics_dep.keys()} metrics_nor = {key: metrics_nor[key] / num_iters for key in metrics_nor.keys()} metrics_deform_bw = {key: metrics_deform_bw[key] / num_iters for key in metrics_deform_bw.keys()} metrics_deform_ori = {key: metrics_deform_ori[key] / num_iters for key in metrics_deform_ori.keys()} metrics_deform_alb = {key: metrics_deform_alb[key] / num_iters for key in metrics_deform_alb.keys()} if args.calculate_CC: for key in metrics_alb.keys(): print( 'Metric:{}\t dep:{:.6f}\t nor:{:.6f}\t cmap:{:.6f}\t alb:{:.6f}\t deform_bw:{:.6f}\t' ' deform_ori:{:.6f}\t deform_alb:{:.6f}\n'.format( key, metrics_dep[key].item(), metrics_nor[key].item(), metrics_cmap[key].item(), metrics_alb[key].item(), metrics_deform_bw[key].item(), metrics_deform_ori[key].item(), metrics_deform_alb[key].item())) if args.write_txt: txt_dir = 'output_txt/' + args.model_name + '.txt' f = open(txt_dir, 'a') f.write( 'Epoch: {} \t Test Loss: {:.6f}, \t ab: {:.4f}, \t cmap: {:.4f}, \t normal: {:.4f}, \t depth: {:.4f}, \t mask: {:.6f} , ' ' \t smooth:{:.6f},\t dice:{:.6f}\n'.format(epoch, test_loss.item(), test_loss_alb.item(), test_loss_threeD.item(), test_loss_nor.item(), test_loss_dep.item(), test_loss_mask.item(), test_loss_smooth.item(), test_dice.item())) if args.calculate_CC: for key in metrics_alb.keys(): f.write( 'Metric:{}\t dep:{:.6f}\t nor:{:.6f}\t cmap:{:.6f}\t alb:{:.6f}\t deform_bw:{:.6f}\t' ' deform_ori:{:.6f}\t deform_alb:{:.6f}\n'.format(key, 
metrics_dep[key].item(), metrics_nor[key].item(), metrics_cmap[key].item(), metrics_alb[key].item(), metrics_deform_bw[key].item(), metrics_deform_ori[key].item(), metrics_deform_alb[key].item())) f.close() if args.write_summary: print('sstep', sstep) # writer.add_scalar('test_acc', 100. * correct / len(test_loader.dataset), global_step=epoch+1) writer.add_scalar('summary/test_loss', test_loss.item(), global_step=sstep) writer.add_scalar('summary/test_loss_ab', test_loss_alb.item(), global_step=sstep) writer.add_scalar('summary/test_loss_cmap', test_loss_threeD.item(), global_step=sstep) writer.add_scalar('summary/test_loss_normal', test_loss_nor.item(), global_step=sstep) writer.add_scalar('summary/test_loss_depth', test_loss_dep.item(), global_step=sstep) writer.add_scalar('summary/test_loss_back', test_loss_mask.item(), global_step=sstep) writer.add_scalar('summary/test_loss_smooth', test_loss_smooth.item(), global_step=sstep)