required=False, help='Width of output light probe.') opt = parser.parse_args() if opt.calib_fp[:2] == '_/': opt.calib_fp = os.path.join(opt.data_root, opt.calib_fp[2:]) if opt.obj_fp[:2] == '_/': opt.obj_fp = os.path.join(opt.data_root, opt.obj_fp[2:]) img_dir = os.path.join(opt.data_root, 'rgb' + str(opt.lighting_idx)) # save directories save_dir_lp = os.path.join(opt.data_root, 'light_probe_stitch_' + opt.sampling_pattern) save_dir_lp_mask = os.path.join(save_dir_lp, 'mask') save_dir_lp_count = os.path.join(save_dir_lp, 'count') data_util.cond_mkdir(save_dir_lp) data_util.cond_mkdir(save_dir_lp_mask) data_util.cond_mkdir(save_dir_lp_count) # load calibration and mesh calib = scipy.io.loadmat(opt.calib_fp) poses = calib['poses'] projs = calib['projs'] img_hws = calib['img_hws'] global_RT = calib['global_RT'] global_RT_inv = np.linalg.inv(global_RT) num_view = poses.shape[0] mesh = trimesh.load(opt.obj_fp, process=False) vertices = np.dot( global_RT,
def main():
    """Rasterize every calibrated view once and cache the results on disk.

    Buffers the views of a ``dataio.ViewDataset`` built from ``opt`` and, for
    each view, rasterizes the mesh with the module-level ``rasterizer``.  The
    raster buffers (``face_index_map``, ``weight_map``, ``faces_v_idx``,
    ``v_uvz``, ``v_front_mask``) are always written to
    ``<data_root>/precomp_<obj_name>/resol_<img_size>/raster/<img>.mat``.

    Unless ``opt.only_mesh_related`` is set, it additionally writes the
    ground-truth image, pose / projection matrices and the uv, alpha,
    position, normal, view-direction (world, camera and tangent space),
    SH-basis, reflection-direction and TBN maps (``.mat`` files, most with a
    PNG in a ``preview`` sub-folder).  Camera-space normal and view-direction
    maps are multiplied by ``[1, -1, -1]`` before saving ("z-out" space).

    Relies on the module-level ``opt``, ``obj_name``, ``device``,
    ``rasterizer``, ``render``, ``camera`` and ``sph_harm``.
    """
    view_dataset = dataio.ViewDataset(
        root_dir=opt.data_root,
        img_dir=opt.img_dir,
        calib_path=opt.calib_fp,
        calib_format=opt.calib_format,
        img_size=[opt.img_size, opt.img_size],
        sampling_pattern=opt.sampling_pattern,
        ignore_dist_coeffs=True,
        load_precompute=False,
    )
    print('Start buffering view data...')
    view_dataset.buffer_all()
    view_dataloader = DataLoader(view_dataset, batch_size=1, shuffle=False, num_workers=8)

    # ---- output directories ----------------------------------------------
    precomp_root = os.path.join(opt.data_root, 'precomp_' + obj_name)
    resol_root = os.path.join(precomp_root, 'resol_' + str(opt.img_size))

    # Dict insertion order is the creation order: every parent folder is
    # created before its 'preview' child.
    out_dir = {'raster': os.path.join(resol_root, 'raster')}
    if not opt.only_mesh_related:
        out_dir['pose'] = os.path.join(precomp_root, 'pose')
        # (sub-folder under resol_<size>, whether it gets a 'preview' child)
        per_resolution = (
            ('proj', False),
            ('img_gt', False),
            ('uv_map', True),
            ('alpha_map', False),
            ('position_map', False),
            ('position_map_cam', False),
            ('normal_map', True),
            ('normal_map_cam', True),
            ('view_dir_map', True),
            ('view_dir_map_cam', True),
            ('view_dir_map_tangent', True),
            ('sh_basis_map', False),
            ('reflect_dir_map', True),
            ('reflect_dir_map_cam', True),
            ('TBN_map', True),
        )
        for sub_name, has_preview in per_resolution:
            out_dir[sub_name] = os.path.join(resol_root, sub_name)
            if has_preview:
                out_dir[sub_name + '_preview'] = os.path.join(out_dir[sub_name], 'preview')
    for path in out_dir.values():
        data_util.cond_mkdir(path)

    # ---- small I/O helpers -----------------------------------------------
    def to_np(tensor):
        """Detach ``tensor`` and return it as a host numpy array."""
        return tensor.cpu().detach().numpy()

    def save_mat(dir_key, img_fn, **arrays):
        """Write the keyword arrays to ``<out_dir[dir_key]>/<img_fn>.mat``."""
        scipy.io.savemat(os.path.join(out_dir[dir_key], img_fn + '.mat'), arrays)

    def save_preview(dir_key, file_name, img_rgb):
        """Write an RGB image with values in [0, 1] as a BGR PNG scaled to [0, 255]."""
        cv2.imwrite(os.path.join(out_dir[dir_key], file_name), img_rgb[:, :, ::-1] * 255)

    # flips the y and z axes of camera-space vectors ("z-out" space)
    z_out = np.array([1, -1, -1], dtype=np.float32)[None, None, None, :]

    print('Precompute view-related data...')
    for view_trgt in view_dataloader:
        sample = view_trgt[0]
        img_gt = to_np(sample['img_gt'][0, :].permute((1, 2, 0))) * 255.0
        proj_orig = sample['proj_orig'].to(device)
        proj = sample['proj'].to(device)
        proj_inv = sample['proj_inv'].to(device)
        R_inv = sample['R_inv'].to(device)
        pose = sample['pose'].to(device)
        img_fn = sample['img_fn'][0].split('.')[0]

        # rasterize the mesh into this view
        (uv_map, alpha_map, face_index_map, weight_map, faces_v_idx,
         normal_map, normal_map_cam, faces_v, faces_vt, position_map,
         position_map_cam, depth, v_uvz, v_front_mask) = rasterizer(
            proj=proj,
            pose=pose,
            dist_coeffs=sample['dist_coeffs'].to(device),
            offset=None,
            scale=None,
        )

        save_mat('raster', img_fn,
                 face_index_map=to_np(face_index_map[0, :]),
                 weight_map=to_np(weight_map[0, :]),
                 faces_v_idx=to_np(faces_v_idx[0, :]),
                 v_uvz=to_np(v_uvz[0, :]),
                 v_front_mask=to_np(v_front_mask[0, :]))

        if not opt.only_mesh_related:
            # ground-truth image, RGB -> BGR for OpenCV
            cv2.imwrite(os.path.join(out_dir['img_gt'], img_fn + '.png'), img_gt[:, :, ::-1])

            # per-pixel TBN frame; one preview per frame axis (last index 0, 1, 2)
            TBN_map = render.get_TBN_map(normal_map, face_index_map,
                                         faces_v=faces_v[0, :],
                                         faces_texcoord=faces_vt[0, :],
                                         tangent=None)
            save_mat('TBN_map', img_fn, TBN_map=to_np(TBN_map[0, :]))
            for axis in range(3):
                save_preview('TBN_map_preview', '%s_%d.png' % (img_fn, axis),
                             to_np(TBN_map[0, ..., axis]) * 0.5 + 0.5)

            # Remove padded regions: pixels whose first GT channel exceeds 2.0 * 255
            # are dropped from the alpha mask (presumably padding is marked with
            # values above the 2.0 image maximum -- confirm against the dataset).
            keep = torch.from_numpy(img_gt[:, :, 0] <= (2.0 * 255))
            alpha_map = alpha_map * keep.to(alpha_map.dtype).to(alpha_map.device)

            uv_map = to_np(uv_map)
            alpha_map = to_np(alpha_map)
            normal_map = to_np(normal_map)
            normal_map_cam = to_np(normal_map_cam)
            position_map = to_np(position_map)
            position_map_cam = to_np(position_map_cam)

            # camera matrices
            save_mat('pose', img_fn, pose=to_np(pose[0, :]), proj_orig=to_np(proj_orig[0, :]))
            save_mat('proj', img_fn, proj=to_np(proj[0, :]))

            # uv map (+ preview with a zero third channel) and alpha mask
            save_mat('uv_map', img_fn, uv_map=uv_map[0, :])
            uv_map_img = np.concatenate((uv_map[0, :, :, :], np.zeros((*uv_map.shape[1:3], 1))), axis=2)
            save_preview('uv_map_preview', img_fn + '.png', uv_map_img)
            cv2.imwrite(os.path.join(out_dir['alpha_map'], img_fn + '.png'), alpha_map[0, :] * 255)

            # normals: world space and camera space (z-out)
            save_mat('normal_map', img_fn, normal_map=normal_map[0, :])
            normal_map_cam = normal_map_cam * z_out
            save_mat('normal_map_cam', img_fn, normal_map_cam=normal_map_cam[0, :])
            save_preview('normal_map_preview', img_fn + '.png', (normal_map[0, :, :, :] + 1.0) / 2)
            save_preview('normal_map_cam_preview', img_fn + '.png', (normal_map_cam[0, :, :, :] + 1.0) / 2)

            # positions: world space and camera space
            save_mat('position_map', img_fn, position_map=position_map[0, :])
            save_mat('position_map_cam', img_fn, position_map_cam=position_map_cam[0, :])

            # view directions: world space and camera space (z-out)
            view_dir_map, view_dir_map_cam = camera.get_view_dir_map(img_gt.shape[:2], proj_inv, R_inv)
            view_dir_map_cam = to_np(view_dir_map_cam) * z_out
            view_dir_np = to_np(view_dir_map)
            save_mat('view_dir_map', img_fn, view_dir_map=view_dir_np[0, :])
            save_mat('view_dir_map_cam', img_fn, view_dir_map_cam=view_dir_map_cam[0, :])
            save_preview('view_dir_map_preview', img_fn + '.png', (view_dir_np[0, :, :, :] + 1.0) / 2)
            save_preview('view_dir_map_cam_preview', img_fn + '.png', (view_dir_map_cam[0, :, :, :] + 1.0) / 2)

            # view directions in tangent space: TBN^T @ v, then re-normalized
            tbn_t = TBN_map.reshape((-1, 3, 3)).transpose(-2, -1)
            view_dir_map_tangent = torch.matmul(tbn_t, view_dir_map.reshape((-1, 3, 1)))[..., 0]
            view_dir_map_tangent = view_dir_map_tangent.reshape(view_dir_map.shape)
            view_dir_map_tangent = torch.nn.functional.normalize(view_dir_map_tangent, dim=-1)
            tangent_np = to_np(view_dir_map_tangent)
            save_mat('view_dir_map_tangent', img_fn, view_dir_map_tangent=tangent_np[0, :])
            save_preview('view_dir_map_tangent_preview', img_fn + '.png', (tangent_np[0, :, :, :] + 1.0) / 2)

            # SH basis (lmax = 2) evaluated at the world-space view directions
            sh_flat = sph_harm.evaluate_sh_basis(lmax=2, directions=to_np(view_dir_map.reshape((-1, 3))))
            sh_basis_map = sh_flat.reshape((*(view_dir_map.shape[:3]), -1)).astype(np.float32)  # [N, H, W, 9]
            save_mat('sh_basis_map', img_fn, sh_basis_map=sh_basis_map[0, :])

            # reflection directions, masked by alpha
            reflect_dir_map = to_np(camera.get_reflect_dir(
                view_dir_map.to(device),
                torch.from_numpy(normal_map).to(device))) * alpha_map[..., None]
            reflect_dir_map_cam = to_np(camera.get_reflect_dir(
                torch.from_numpy(view_dir_map_cam).to(device),
                torch.from_numpy(normal_map_cam).to(device))) * alpha_map[..., None]
            save_mat('reflect_dir_map', img_fn, reflect_dir_map=reflect_dir_map[0, :])  # [H, W, 3]
            save_mat('reflect_dir_map_cam', img_fn, reflect_dir_map_cam=reflect_dir_map_cam[0, :])
            save_preview('reflect_dir_map_preview', img_fn + '.png', (reflect_dir_map[0, :, :, :] + 1.0) / 2)
            save_preview('reflect_dir_map_cam_preview', img_fn + '.png', (reflect_dir_map_cam[0, :, :, :] + 1.0) / 2)

        view_idx = to_np(sample['idx']).item()
        if view_idx % 10 == 0:
            print('View', view_idx)
def main():
    """Render every calibrated view with the trained neural texture and rendering net.

    For each of the ``num_view`` views in ``view_dataset`` the mesh is
    rasterized, an SH-basis map of the world-space view directions is loaded
    from (or computed into) ``<calib_dir>/resol_<img_size>/precomp/sh_basis_map``,
    and ``render_net(texture_mapper(uv_map, sh_basis_map))`` is evaluated.
    The alpha map and the alpha-masked rendering are written as PNGs into an
    output folder derived from ``opt.checkpoint_dir`` / ``opt.checkpoint_name``,
    and all options are dumped to ``params.txt`` there.

    Relies on the module-level ``opt``, ``device``, ``view_dataset``,
    ``num_view``, ``rasterizer``, ``texture_mapper``, ``render_net``,
    ``camera`` and ``sph_harm``.
    """
    view_dataset.buffer_all()

    # Output folder: <calib_dir>/resol_<size>/<checkpoint parent>/<a>_<b>_<ckpt id>,
    # with <a>, <b> the first two '_'-separated fields of the checkpoint folder
    # name and <ckpt id> the part of checkpoint_name after its last '-' (no
    # extension).  A trailing '/' is stripped first: otherwise the last path
    # component is '' and indexing its '_' fields raises IndexError.
    ckpt_parts = opt.checkpoint_dir.rstrip('/').split('/')
    ckpt_fields = ckpt_parts[-1].split('_')
    ckpt_id = opt.checkpoint_name.split('-')[-1].split('.')[0]
    log_dir = os.path.join(
        opt.calib_dir, 'resol_' + str(opt.img_size), ckpt_parts[-2],
        ckpt_fields[0] + '_' + ckpt_fields[1] + '_' + ckpt_id)
    data_util.cond_mkdir(log_dir)

    save_dir_img_est = os.path.join(log_dir, 'img_est')
    data_util.cond_mkdir(save_dir_img_est)

    save_dir_alpha_map = os.path.join(log_dir, 'alpha_map')
    data_util.cond_mkdir(save_dir_alpha_map)

    # SH-basis maps depend only on the camera, so they are cached next to the calibration.
    save_dir_sh_basis_map = os.path.join(opt.calib_dir, 'resol_' + str(opt.img_size), 'precomp', 'sh_basis_map')
    data_util.cond_mkdir(save_dir_sh_basis_map)

    # Save all command line arguments into a txt file in the logging directory for later reference.
    with open(os.path.join(log_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    img_max_val = 2.0

    print('Begin inference...')
    with torch.no_grad():
        for ithView in range(num_view):
            start = time.time()

            # get view data and add a batch dimension
            view_trgt = view_dataset[ithView]
            proj = view_trgt[0]['proj'].to(device)[None, :]
            pose = view_trgt[0]['pose'].to(device)[None, :]
            proj_inv = view_trgt[0]['proj_inv'].to(device)[None, :]
            R_inv = view_trgt[0]['R_inv'].to(device)[None, :]

            # rasterize
            uv_map, alpha_map, face_index_map, weight_map, faces_v_idx, normal_map, normal_map_cam, \
                faces_v, faces_vt, position_map, position_map_cam, depth, v_uvz, v_front_mask = \
                rasterizer(proj=proj.cuda(0),
                           pose=pose.cuda(0),
                           dist_coeffs=None,
                           offset=None,
                           scale=None,
                           )
            # The rasterizer inputs are pinned to cuda:0; move the outputs we use
            # to the network device so they match sh_basis_map / the models when
            # `device` is a different GPU (a no-op when device is cuda:0).
            uv_map = uv_map.to(device)
            alpha_map = alpha_map.to(device)

            # save alpha map
            cv2.imwrite(
                os.path.join(save_dir_alpha_map, str(ithView).zfill(5) + '.png'),
                alpha_map[0, :, :, None].cpu().detach().numpy()[:, :, ::-1] * 255.)

            sh_basis_map_fp = os.path.join(save_dir_sh_basis_map, str(ithView).zfill(5) + '.mat')
            if opt.force_recompute or not os.path.isfile(sh_basis_map_fp):
                print('Compute sh_basis_map...')
                # compute view_dir_map in world space
                view_dir_map, _ = camera.get_view_dir_map(uv_map.shape[1:3], proj_inv, R_inv)
                # SH basis value for view_dir_map
                sh_basis_map = sph_harm.evaluate_sh_basis(
                    lmax=2,
                    directions=view_dir_map.reshape((-1, 3)).cpu().detach().numpy()).reshape(
                        (*(view_dir_map.shape[:3]), -1)).astype(np.float32)  # [N, H, W, 9]
                # save
                scipy.io.savemat(sh_basis_map_fp, {'sh_basis_map': sh_basis_map[0, :]})
            else:
                # cached map was saved without the batch axis; restore it
                sh_basis_map = scipy.io.loadmat(sh_basis_map_fp)['sh_basis_map'][None, ...]
            sh_basis_map = torch.from_numpy(sh_basis_map).to(device)

            # sample texture
            neural_img = texture_mapper(uv_map, sh_basis_map)

            # rendering net
            outputs_final = render_net(neural_img, None)
            outputs_final = (outputs_final * 0.5 + 0.5) * img_max_val  # map to [0, img_max_val]
            outputs_final = outputs_final * alpha_map[:, None, ...]

            # save
            cv2.imwrite(
                os.path.join(save_dir_img_est, str(ithView).zfill(5) + '.png'),
                outputs_final[0, :].permute((1, 2, 0)).cpu().detach().numpy()[:, :, ::-1] * 255.)

            end = time.time()
            print("View %07d t_total %0.4f" % (ithView, end - start))
def main():
    """Train ``texture_mapper`` + ``render_net`` with periodic validation and checkpoints.

    Buffers the training and validation view datasets, creates a timestamped
    logging directory under ``opt.logging_root`` (with ``val_out``,
    ``val_gt`` and ``val_err`` sub-folders and a ``params.txt`` dump of
    ``opt``), then runs epochs ``opt.start_epoch .. opt.max_epoch - 1``:

    * every iteration: L1 loss between the alpha-masked rendering and the
      ground truth (ignoring a 5-pixel border), optimizer step, scalar logs;
    * every ``opt.log_freq`` iterations: an output / gt / |error| image grid;
    * every ``opt.val_freq`` iterations: a full validation pass with metrics,
      TensorBoard logs and saved PNGs;
    * every ``opt.ckp_freq`` iterations and once at the end: a checkpoint.

    Relies on the module-level ``opt``, ``device``, ``view_dataset``,
    ``view_val_dataset``, ``texture_mapper``, ``render_net``,
    ``criterionL1``, ``optimizerG``, ``part_list`` and ``part_name_list``.
    """
    print('Start buffering data for training views...')
    view_dataset.buffer_all()
    view_dataloader = DataLoader(view_dataset, batch_size=opt.batch_size, shuffle=True, num_workers=8)

    print('Start buffering data for validation views...')
    view_val_dataset.buffer_all()
    view_val_dataloader = DataLoader(view_val_dataset, batch_size=opt.batch_size, shuffle=False, num_workers=8)

    # Directory name contains some info about hyperparameters.  The clock is
    # read once so the date and time parts cannot straddle a boundary.
    now = datetime.datetime.now()
    dir_name = (now.strftime('%m-%d') + '_' + now.strftime('%H-%M-%S') + '_'
                + opt.sampling_pattern + '_' + opt.data_root.strip('/').split('/')[-1])
    # Truthiness test: `is not ''` compared identity (SyntaxWarning) and crashed on None.
    if opt.exp_name:
        dir_name += '_' + opt.exp_name

    # directory for logging
    log_dir = os.path.join(opt.logging_root, dir_name)
    data_util.cond_mkdir(log_dir)

    # directory for saving validation data on view synthesis
    val_out_dir = os.path.join(log_dir, 'val_out')
    val_gt_dir = os.path.join(log_dir, 'val_gt')
    val_err_dir = os.path.join(log_dir, 'val_err')
    data_util.cond_mkdir(val_out_dir)
    data_util.cond_mkdir(val_gt_dir)
    data_util.cond_mkdir(val_err_dir)

    # Save all command line arguments into a txt file in the logging directory for later reference.
    with open(os.path.join(log_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    writer = SummaryWriter(log_dir)

    def output_vs_gt_grid(outputs, img_gt, num_pairs, nrow):
        """Return a numpy image grid of (output, gt, |output - gt|) tiles, all clamped to [0, 1]."""
        tiles = []
        for i in range(num_pairs):
            tiles.append(outputs[i].clamp(min=0., max=1.))
            tiles.append(img_gt[i].clamp(min=0., max=1.))
            tiles.append((outputs[i] - img_gt[i]).abs().clamp(min=0., max=1.))
        # normalize=False, so no value range is required; the former `range=`
        # kwarg was ignored in that case and no longer exists in recent torchvision.
        grid = torchvision.utils.make_grid(torch.cat(tiles, dim=0), nrow=nrow,
                                           scale_each=False, normalize=False)
        return grid.cpu().detach().numpy()

    img_max_val = 2.0
    # global iteration counter (named to avoid shadowing the builtin `iter`)
    iter_idx = opt.start_epoch * len(view_dataset)

    print('Begin training...')
    val_log_batch_id = 0
    first_val = True
    for epoch in range(opt.start_epoch, opt.max_epoch):
        for view_trgt in view_dataloader:
            start = time.time()

            # get view data
            uv_map = view_trgt[0]['uv_map'].to(device)  # [N, H, W, 2]
            sh_basis_map = view_trgt[0]['sh_basis_map'].to(device)  # [N, H, W, 9]
            alpha_map = view_trgt[0]['alpha_map'][:, None, :, :].to(device)  # [N, 1, H, W]
            img_gt = [view['img_gt'].to(device) for view in view_trgt]

            # sample texture
            neural_img = texture_mapper(uv_map, sh_basis_map)

            # rendering net
            outputs = render_net(neural_img, None)
            outputs = (outputs * 0.5 + 0.5) * img_max_val  # map to [0, img_max_val]
            if not isinstance(outputs, list):
                outputs = [outputs]

            # We don't enforce a loss on the outermost 5 pixels to alleviate boundary errors, also weight loss by alpha
            alpha_map_central = alpha_map[:, :, 5:-5, 5:-5]
            for i in range(len(view_trgt)):
                outputs[i] = outputs[i][:, :, 5:-5, 5:-5] * alpha_map_central
                img_gt[i] = img_gt[i][:, :, 5:-5, 5:-5] * alpha_map_central

            # loss on final image
            loss_rn = torch.stack(
                [criterionL1(outputs[idx].contiguous().view(-1).float(),
                             img_gt[idx].contiguous().view(-1).float())
                 for idx in range(len(view_trgt))],
                dim=0).mean()

            # total loss
            loss_g = loss_rn

            optimizerG.zero_grad()
            loss_g.backward()
            optimizerG.step()

            # error metrics
            with torch.no_grad():
                err_metrics_batch_i = metric.compute_err_metrics_batch(
                    outputs[0] * 255.0, img_gt[0] * 255.0, alpha_map_central, compute_ssim=False)

            # tensorboard scalar logs of training data
            writer.add_scalar("loss_g", loss_g, iter_idx)
            writer.add_scalar("loss_rn", loss_rn, iter_idx)
            writer.add_scalar("final_mae_valid", err_metrics_batch_i['mae_valid_mean'], iter_idx)
            writer.add_scalar("final_psnr_valid", err_metrics_batch_i['psnr_valid_mean'], iter_idx)

            end = time.time()
            print("Iter %07d Epoch %03d loss_g %0.4f mae_valid %0.4f psnr_valid %0.4f t_total %0.4f"
                  % (iter_idx, epoch, loss_g, err_metrics_batch_i['mae_valid_mean'],
                     err_metrics_batch_i['psnr_valid_mean'], end - start))

            # tensorboard figure logs of training data
            if not iter_idx % opt.log_freq:
                writer.add_image(
                    "output_final_vs_gt",
                    output_vs_gt_grid(outputs, img_gt, len(view_trgt), outputs[0].shape[0]),
                    iter_idx)

            # validation
            if not iter_idx % opt.val_freq:
                start_val = time.time()
                with torch.no_grad():
                    # error metrics
                    err_metrics_val = {'mae_valid': [], 'mse_valid': [], 'psnr_valid': [], 'ssim_valid': []}
                    # loop over batches
                    batch_id = 0
                    for view_val_trgt in view_val_dataloader:
                        start_val_i = time.time()

                        # get view data
                        uv_map = view_val_trgt[0]['uv_map'].to(device)  # [N, H, W, 2]
                        sh_basis_map = view_val_trgt[0]['sh_basis_map'].to(device)  # [N, H, W, 9]
                        alpha_map = view_val_trgt[0]['alpha_map'][:, None, :, :].to(device)  # [N, 1, H, W]
                        view_idx = view_val_trgt[0]['idx']

                        batch_size = alpha_map.shape[0]
                        num_val_view = len(view_val_trgt)

                        img_gt = [view['img_gt'].to(device) for view in view_val_trgt]

                        # sample texture
                        neural_img = texture_mapper(uv_map, sh_basis_map)

                        # rendering net
                        outputs = render_net(neural_img, None)
                        outputs = (outputs * 0.5 + 0.5) * img_max_val  # map to [0, img_max_val]
                        if not isinstance(outputs, list):
                            outputs = [outputs]

                        # apply alpha
                        for i in range(num_val_view):
                            outputs[i] = outputs[i] * alpha_map
                            img_gt[i] = img_gt[i] * alpha_map

                        # tensorboard figure logs of validation data (one batch per validation pass)
                        if batch_id == val_log_batch_id:
                            writer.add_image(
                                "output_final_vs_gt_val",
                                output_vs_gt_grid(outputs, img_gt, num_val_view, batch_size),
                                iter_idx)

                        # error metrics
                        err_metrics_batch_i_final = metric.compute_err_metrics_batch(
                            outputs[0] * 255.0, img_gt[0] * 255.0, alpha_map, compute_ssim=True)

                        for i in range(batch_size):
                            for key in list(err_metrics_val.keys()):
                                if key in err_metrics_batch_i_final.keys():
                                    err_metrics_val[key].append(err_metrics_batch_i_final[key][i])

                        # save images
                        for i in range(batch_size):
                            view_name = str(view_idx[i].cpu().detach().numpy()).zfill(5) + '.png'
                            cv2.imwrite(
                                os.path.join(val_out_dir, str(iter_idx).zfill(8) + '_' + view_name),
                                outputs[0][i, :].permute((1, 2, 0)).cpu().detach().numpy()[:, :, ::-1] * 255.)
                            cv2.imwrite(
                                os.path.join(val_err_dir, str(iter_idx).zfill(8) + '_' + view_name),
                                (outputs[0] - img_gt[0]).abs().clamp(min=0., max=1.)[i, :].permute(
                                    (1, 2, 0)).cpu().detach().numpy()[:, :, ::-1] * 255.)
                            if first_val:
                                cv2.imwrite(
                                    os.path.join(val_gt_dir, view_name),
                                    img_gt[0][i, :].permute((1, 2, 0)).cpu().detach().numpy()[:, :, ::-1] * 255.)

                        end_val_i = time.time()
                        print("Val batch %03d mae_valid %0.4f psnr_valid %0.4f ssim_valid %0.4f t_total %0.4f"
                              % (batch_id, err_metrics_batch_i_final['mae_valid_mean'],
                                 err_metrics_batch_i_final['psnr_valid_mean'],
                                 err_metrics_batch_i_final['ssim_valid_mean'], end_val_i - start_val_i))

                        batch_id += 1

                    for key in list(err_metrics_val.keys()):
                        if err_metrics_val[key]:
                            err_metrics_val[key] = np.vstack(err_metrics_val[key])
                            err_metrics_val[key + '_mean'] = err_metrics_val[key].mean()
                        else:
                            err_metrics_val[key + '_mean'] = np.nan

                    # tensorboard scalar logs of validation data
                    writer.add_scalar("final_mae_valid_val", err_metrics_val['mae_valid_mean'], iter_idx)
                    writer.add_scalar("final_psnr_valid_val", err_metrics_val['psnr_valid_mean'], iter_idx)
                    writer.add_scalar("final_ssim_valid_val", err_metrics_val['ssim_valid_mean'], iter_idx)

                    first_val = False
                    # rotate which validation batch gets logged as an image next time;
                    # guarded because an empty validation loader leaves batch_id == 0
                    if batch_id:
                        val_log_batch_id = (val_log_batch_id + 1) % batch_id

                    end_val = time.time()
                    print("Val mae_valid %0.4f psnr_valid %0.4f ssim_valid %0.4f t_total %0.4f"
                          % (err_metrics_val['mae_valid_mean'], err_metrics_val['psnr_valid_mean'],
                             err_metrics_val['ssim_valid_mean'], end_val - start_val))

            iter_idx += 1

            if iter_idx % opt.ckp_freq == 0:
                util.custom_save(
                    os.path.join(log_dir, 'model_epoch-%d_iter-%s.pth' % (epoch, iter_idx)),
                    part_list, part_name_list)

    util.custom_save(
        os.path.join(log_dir, 'model_epoch-%d_iter-%s.pth' % (epoch, iter_idx)),
        part_list, part_name_list)
def main():
    """Relight every calibrated view with the trained ray-based neural renderer.

    For each of the ``num_view`` views in ``view_dataset``:

    1. rasterize the mesh (uv, alpha, normal, face-index maps);
    2. build the per-pixel TBN frame and world / tangent-space view directions;
    3. load or compute (and cache) the SH-basis map of the view directions;
    4. sample specular and diffuse rays, predict their light transport with
       ``render_net`` and, for every lighting in ``lighting_idx_all``,
       shade with ``ray_renderer`` and save the result as a PNG;
    5. optionally (``opt.save_img_bg``) save the light probe as seen behind
       the object, both from ``ray_renderer`` and from ``lighting_model_lp``.

    The estimated training illumination is saved once as ``lp_est.png``.
    Relies on the module-level ``opt``, ``params``, ``device``,
    ``view_dataset``, ``num_view``, ``rasterizer``, ``texture_mapper``,
    ``ray_sampler``, ``ray_sampler_diffuse``, ``render_net``, ``v_feature``,
    ``ray_renderer``, ``lighting_model_train``, ``lighting_model_lp``,
    ``interpolater``, ``render``, ``camera`` and ``sph_harm``.
    """
    view_dataset.buffer_all()

    if opt.lighting_type == 'train':
        lighting_idx_all = [int(params['lighting_idx'])]
    else:
        lighting_idx_all = [opt.lighting_idx]

    # Output folder: <calib_dir>/resol_<size>/<checkpoint parent>/<a>_<b>_<ckpt id>,
    # with <a>, <b> the first two '_'-separated fields of the checkpoint folder
    # name and <ckpt id> the part of checkpoint_name after its last '-' (no
    # extension).  A trailing '/' is stripped first: otherwise the last path
    # component is '' and indexing its '_' fields raises IndexError.
    ckpt_parts = opt.checkpoint_dir.rstrip('/').split('/')
    ckpt_fields = ckpt_parts[-1].split('_')
    ckpt_id = opt.checkpoint_name.split('-')[-1].split('.')[0]
    log_dir = os.path.join(
        opt.calib_dir, 'resol_' + str(opt.img_size), ckpt_parts[-2],
        ckpt_fields[0] + '_' + ckpt_fields[1] + '_' + ckpt_id)
    data_util.cond_mkdir(log_dir)

    # get estimated illumination (inference only: no autograd graph needed)
    with torch.no_grad():
        lp_est = lighting_model_train.to(device)(lighting_idx=int(params['lighting_idx']), is_lp=True)
    cv2.imwrite(os.path.join(log_dir, 'lp_est.png'), lp_est.cpu().detach().numpy()[0, :, :, ::-1] * 255.0)

    save_dir_alpha_map = os.path.join(log_dir, 'alpha_map')
    data_util.cond_mkdir(save_dir_alpha_map)

    # SH-basis maps depend only on the camera, so they are cached next to the calibration.
    save_dir_sh_basis_map = os.path.join(opt.calib_dir, 'resol_' + str(opt.img_size), 'precomp', 'sh_basis_map')
    data_util.cond_mkdir(save_dir_sh_basis_map)

    # Save all command line arguments into a txt file in the logging directory for later reference.
    with open(os.path.join(log_dir, "params.txt"), "w") as out_file:
        out_file.write('\n'.join(["%s: %s" % (key, value) for key, value in vars(opt).items()]))

    # Per-lighting output folders do not depend on the view: create them once
    # instead of re-joining and re-checking them for every view.
    save_dirs_img_est = {}
    save_dirs_img_bg = {}
    for lighting_idx in lighting_idx_all:
        lighting_tag = opt.lighting_type + '_' + str(lighting_idx).zfill(3)
        save_dirs_img_est[lighting_idx] = os.path.join(log_dir, 'img_est_' + lighting_tag)
        data_util.cond_mkdir(save_dirs_img_est[lighting_idx])
        if opt.save_img_bg:
            save_dirs_img_bg[lighting_idx] = os.path.join(log_dir, 'img_bg_' + lighting_tag)
            data_util.cond_mkdir(save_dirs_img_bg[lighting_idx])

    def sample_probe(probe, view_uv_map):
        """Look up ``probe`` (presumably [N, H, W, C]) at ``view_uv_map`` via ``interpolater``.

        uv in [0, 1] is scaled to pixel units and clamped to the last column/row.
        """
        return interpolater(
            probe,
            (view_uv_map[..., 0] * float(probe.shape[2])).clamp(max=probe.shape[2] - 1),
            (view_uv_map[..., 1] * float(probe.shape[1])).clamp(max=probe.shape[1] - 1))  # [N, H, W, C]

    lt_max_val = 2.0

    print('Begin inference...')
    with torch.no_grad():
        for ithView in range(num_view):
            t_prep = time.time()

            # get view data and add a batch dimension
            view_trgt = view_dataset[ithView]
            proj = view_trgt[0]['proj'].to(device)[None, :]
            pose = view_trgt[0]['pose'].to(device)[None, :]
            proj_inv = view_trgt[0]['proj_inv'].to(device)[None, :]
            R_inv = view_trgt[0]['R_inv'].to(device)[None, :]

            t_prep = time.time() - t_prep

            t_raster = time.time()
            # rasterize
            uv_map, alpha_map, face_index_map, weight_map, faces_v_idx, normal_map, normal_map_cam, \
                faces_v, faces_vt, position_map, position_map_cam, depth, v_uvz, v_front_mask = \
                rasterizer(proj=proj.cuda(0),
                           pose=pose.cuda(0),
                           dist_coeffs=None,
                           offset=None,
                           scale=None,
                           )
            # rasterizer inputs are pinned to cuda:0; bring the used outputs to `device`
            uv_map = uv_map.to(device)
            alpha_map = alpha_map.to(device)
            face_index_map = face_index_map.to(device)
            normal_map = normal_map.to(device)
            faces_v = faces_v.to(device)
            faces_vt = faces_vt.to(device)
            t_raster = time.time() - t_raster

            # save alpha map
            cv2.imwrite(
                os.path.join(save_dir_alpha_map, str(ithView).zfill(5) + '.png'),
                alpha_map[0, :, :, None].cpu().detach().numpy()[:, :, ::-1] * 255.)

            t_preproc = time.time()
            batch_size = alpha_map.shape[0]
            img_h = alpha_map.shape[1]
            img_w = alpha_map.shape[2]

            # compute TBN_map
            TBN_map = render.get_TBN_map(normal_map, face_index_map,
                                         faces_v=faces_v[0, :],
                                         faces_texcoord=faces_vt[0, :],
                                         tangent=None)
            # compute view_dir_map in world space
            view_dir_map, _ = camera.get_view_dir_map(uv_map.shape[1:3], proj_inv, R_inv)
            # compute view_dir_map in tangent space: TBN^T @ v, then re-normalized
            view_dir_map_tangent = torch.matmul(
                TBN_map.reshape((-1, 3, 3)).transpose(-2, -1),
                view_dir_map.reshape((-1, 3, 1)))[..., 0].reshape(view_dir_map.shape)
            view_dir_map_tangent = torch.nn.functional.normalize(view_dir_map_tangent, dim=-1)
            t_preproc = time.time() - t_preproc

            t_sh = time.time()
            # SH basis value for view_dir_map (cached per view)
            sh_basis_map_fp = os.path.join(save_dir_sh_basis_map, str(ithView).zfill(5) + '.mat')
            if opt.force_recompute or not os.path.isfile(sh_basis_map_fp):
                print('Compute sh_basis_map...')
                sh_basis_map = sph_harm.evaluate_sh_basis(
                    lmax=2,
                    directions=view_dir_map.reshape((-1, 3)).cpu().detach().numpy()).reshape(
                        (*(view_dir_map.shape[:3]), -1)).astype(np.float32)  # [N, H, W, 9]
                # save
                scipy.io.savemat(sh_basis_map_fp, {'sh_basis_map': sh_basis_map[0, :]})
            else:
                # cached map was saved without the batch axis; restore it
                sh_basis_map = scipy.io.loadmat(sh_basis_map_fp)['sh_basis_map'][None, ...]
            sh_basis_map = torch.from_numpy(sh_basis_map).to(device)
            t_sh = time.time() - t_sh

            t_network = time.time()
            # sample texture; first 3 channels diffuse albedo, next 3 specular albedo
            neural_img = texture_mapper(uv_map, sh_basis_map, sh_start_ch=6)  # [N, C, H, W]
            albedo_diffuse = neural_img[:, :3, :, :]
            albedo_specular = neural_img[:, 3:6, :, :]

            # sample specular rays (tangent-space directions are not needed here)
            rays_dir, rays_uv, _ = ray_sampler(
                TBN_map, view_dir_map_tangent,
                alpha_map[..., None])  # [N, H, W, 3, num_ray], [N, H, W, 2, num_ray]
            num_ray = rays_uv.shape[-1]

            # sample diffuse rays
            rays_diffuse_dir, rays_diffuse_uv, _ = ray_sampler_diffuse(
                TBN_map, view_dir_map_tangent,
                alpha_map[..., None])  # [N, H, W, 3, num_ray], [N, H, W, 2, num_ray]
            num_ray_diffuse = rays_diffuse_uv.shape[-1]
            num_ray_total = num_ray + num_ray_diffuse

            # concat data: specular rays first, diffuse rays last
            rays_dir = torch.cat((rays_dir, rays_diffuse_dir), dim=-1)
            rays_uv = torch.cat((rays_uv, rays_diffuse_uv), dim=-1)

            # estimate light transport for rays
            render_net_input = torch.cat(
                (rays_dir.permute((0, -1, -2, 1, 2)).reshape((batch_size, -1, img_h, img_w)),
                 normal_map.permute((0, 3, 1, 2)),
                 view_dir_map.permute((0, 3, 1, 2)),
                 neural_img),
                dim=1)
            rays_lt = render_net(render_net_input, v_feature).reshape(
                (batch_size, num_ray_total, -1, img_h, img_w))  # [N, num_ray, C, H, W]
            rays_lt = (rays_lt * 0.5 + 0.5) * lt_max_val  # map to [0, lt_max_val]
            t_network = time.time() - t_network

            for lighting_idx in lighting_idx_all:
                print('Lighting', lighting_idx)

                # render using ray_renderer
                t_render = time.time()
                outputs_final, _, _, _, _, _, lp = ray_renderer(
                    albedo_specular, rays_uv, rays_lt,
                    lighting_idx=lighting_idx,
                    albedo_diffuse=albedo_diffuse,
                    num_ray_diffuse=num_ray_diffuse,
                    lp_scale_factor=1,
                    seperate_albedo=True)
                t_render = time.time() - t_render

                print('View:', ithView, t_prep, t_raster, t_preproc, t_sh, t_network, t_render)

                # save rendered image
                cv2.imwrite(
                    os.path.join(save_dirs_img_est[lighting_idx], str(ithView).zfill(5) + '.png'),
                    outputs_final[0, :].permute((1, 2, 0)).cpu().detach().numpy()[:, :, ::-1] * 255.)

                # get background image
                if opt.save_img_bg:
                    save_dir_img_bg = save_dirs_img_bg[lighting_idx]

                    # get view uv on light probe
                    view_uv_map = render.spherical_mapping_batch(
                        -view_dir_map.transpose(1, -1)).transpose(1, -1)  # [N, H, W, 2]

                    lp_sh = lp
                    lp_lp = lighting_model_lp(lighting_idx, is_lp=True).to(device)  # [N, H, W, C]
                    img_bg_sh = sample_probe(lp_sh, view_uv_map)
                    img_bg_lp = sample_probe(lp_lp, view_uv_map)
                    cv2.imwrite(
                        os.path.join(save_dir_img_bg, 'sh_' + str(ithView).zfill(5) + '.png'),
                        img_bg_sh[0, :].cpu().detach().numpy()[:, :, ::-1] * 255.)
                    cv2.imwrite(
                        os.path.join(save_dir_img_bg, 'lp_' + str(ithView).zfill(5) + '.png'),
                        img_bg_lp[0, :].cpu().detach().numpy()[:, :, ::-1] * 255.)