def training_step(self, batch, batch_nb):
    rays, rgbs_target = self.decode_batch(batch)

    # Periodically refresh the cached density volume that guides importance sampling.
    if args.N_importance and 0 == self.global_step % 500:
        self.update_density_volume()

    xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
        rays,
        N_samples=args.N_samples,
        N_importance=args.N_importance,
        lindisp=args.use_disp,
        perturb=args.perturb,
        density_volume=self.density_volume,
        bbox_3D=self.bbox_3d)

    # Convert world coordinates to normalized [0, 1] coordinates inside the bounding box.
    xyz_NDC = (xyz_coarse_sampled - self.bbox_3d[0].view(1, 1, 3)) / (
        self.bbox_3d[1] - self.bbox_3d[0]).view(1, 1, 3)

    # Rendering
    rgbs, disp, acc, depth_pred, alpha, extras = rendering(
        args, self.pose_source_ref, xyz_coarse_sampled, xyz_NDC, z_vals,
        rays_o, rays_d, self.volume, **self.render_kwargs_train)

    log, loss = {}, 0
    if 'rgb0' in extras:
        img_loss0 = img2mse(extras['rgb0'], rgbs_target)
        loss = loss + img_loss0
        psnr0 = mse2psnr2(img_loss0.item())
        self.log('train/PSNR0', psnr0.item(), prog_bar=True)

    ################## loss #####################
    if self.args.with_rgb_loss:
        img_loss = img2mse(rgbs, rgbs_target)
        loss += img_loss
        psnr = mse2psnr2(img_loss.item())

        with torch.no_grad():
            self.log('train/loss', loss, prog_bar=True)
            self.log('train/img_mse_loss', img_loss.item(), prog_bar=False)
            self.log('train/PSNR', psnr.item(), prog_bar=True)

    # if self.global_step == 3999 or self.global_step == 9999:
    #     self.save_ckpt(f'{self.global_step}')
    return {'loss': loss}
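# A minimal standalone sketch of the bounding-box normalization used above:
# world-space samples are mapped into the unit cube so they can index the
# fused feature volume. The helper name `normalize_to_bbox` is hypothetical
# (not part of this codebase); it assumes `bbox` stacks [min_xyz, max_xyz]
# as a (2, 3) tensor, matching how self.bbox_3d is indexed above.
def normalize_to_bbox(xyz, bbox):
    """Map world coordinates (..., 3) into [0, 1]^3 of an axis-aligned bbox."""
    return (xyz - bbox[0].view(1, 1, 3)) / (bbox[1] - bbox[0]).view(1, 1, 3)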
def validation_step(self, batch, batch_nb):
    self.MVSNet.train()
    rays, img = self.decode_batch(batch)
    img = img.cpu()  # (H, W, 3)
    # mask = batch['mask'][0]

    N_rays_all = rays.shape[0]
    H, W = img.shape[:2]

    ################## rendering #####################
    keys = ['val_psnr_all']
    log = init_log({}, keys)
    with torch.no_grad():

        rgbs, depth_preds = [], []
        for chunk_idx in range(N_rays_all // args.chunk +
                               int(N_rays_all % args.chunk > 0)):
            xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
                rays[chunk_idx * args.chunk:(chunk_idx + 1) * args.chunk],
                N_samples=args.N_samples,
                lindisp=args.use_disp)

            # Convert world coordinates to NDC coordinates of the reference view.
            inv_scale = torch.tensor([W - 1, H - 1]).to(device)
            w2c_ref = self.pose_source['w2cs'][0]
            intrinsic_ref = self.pose_source['intrinsics'][0].clone()
            intrinsic_ref[:2] *= args.imgScale_test / args.imgScale_train
            xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref,
                                         xyz_coarse_sampled, inv_scale,
                                         near=self.near_far_source[0],
                                         far=self.near_far_source[1],
                                         pad=args.pad * args.imgScale_test,
                                         lindisp=args.use_disp)

            # Importance sampling guided by the cached density volume.
            if self.density_volume is not None and args.N_importance > 0:
                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher_fine(
                    rays[chunk_idx * args.chunk:(chunk_idx + 1) * args.chunk],
                    self.density_volume, z_vals, xyz_NDC,
                    N_importance=args.N_importance)
                xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref,
                                             xyz_coarse_sampled, inv_scale,
                                             near=self.near_far_source[0],
                                             far=self.near_far_source[1],
                                             pad=args.pad,
                                             lindisp=args.use_disp)

            # Rendering
            rgb, disp, acc, depth_pred, alpha, extras = rendering(
                args, self.pose_source, xyz_coarse_sampled, xyz_NDC, z_vals,
                rays_o, rays_d, self.volume, self.imgs,
                **self.render_kwargs_train)
            rgbs.append(rgb.cpu())
            depth_preds.append(depth_pred.cpu())

        rgbs = torch.clamp(torch.cat(rgbs).reshape(H, W, 3), 0, 1)
        depth_r = torch.cat(depth_preds).reshape(H, W)
        img_err_abs = (rgbs - img).abs()
        log['val_psnr_all'] = mse2psnr(torch.mean(img_err_abs ** 2))

        depth_r, _ = visualize_depth(depth_r, self.near_far_source)
        self.logger.experiment.add_images('val/depth_gt_pred', depth_r[None],
                                          self.global_step)

        img_vis = torch.stack(
            (img, rgbs, img_err_abs.cpu() * 5)).permute(0, 3, 1, 2)
        self.logger.experiment.add_images('val/rgb_pred_err', img_vis,
                                          self.global_step)

        os.makedirs(
            f'runs_fine_tuning/{self.args.expname}/{self.args.expname}/',
            exist_ok=True)
        img_vis = torch.cat((img, rgbs, img_err_abs * 10,
                             depth_r.permute(1, 2, 0)), dim=1).numpy()
        imageio.imwrite(
            f'runs_fine_tuning/{self.args.expname}/{self.args.expname}/'
            f'{self.global_step:08d}_{self.idx:02d}.png',
            (img_vis * 255).astype('uint8'))
        self.idx += 1

    return log
def training_step(self, batch, batch_nb):
    rays, rgbs_target = self.decode_batch(batch)

    # Periodically refresh the cached density volume that guides importance sampling.
    if args.use_density_volume and 0 == self.global_step % 200:
        self.update_density_volume()

    xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
        rays,
        N_samples=args.N_samples,
        lindisp=args.use_disp,
        perturb=args.perturb)

    # Convert world coordinates to NDC coordinates of the reference view.
    H, W = self.imgs.shape[-2:]
    inv_scale = torch.tensor([W - 1, H - 1]).to(device)
    w2c_ref = self.pose_source['w2cs'][0]
    intrinsic_ref = self.pose_source['intrinsics'][0]
    xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref, xyz_coarse_sampled,
                                 inv_scale, near=self.near_far_source[0],
                                 far=self.near_far_source[1], pad=args.pad,
                                 lindisp=args.use_disp)

    # Importance sampling guided by the cached density volume.
    if self.density_volume is not None and args.N_importance > 0:
        xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher_fine(
            rays, self.density_volume, z_vals, xyz_NDC,
            N_importance=args.N_importance)
        xyz_NDC = get_ndc_coordinate(w2c_ref, intrinsic_ref,
                                     xyz_coarse_sampled, inv_scale,
                                     near=self.near_far_source[0],
                                     far=self.near_far_source[1],
                                     pad=args.pad, lindisp=args.use_disp)

    # Rendering
    rgbs, disp, acc, depth_pred, alpha, extras = rendering(
        args, self.pose_source, xyz_coarse_sampled, xyz_NDC, z_vals,
        rays_o, rays_d, self.volume, self.imgs, **self.render_kwargs_train)

    log, loss = {}, 0
    if 'rgb0' in extras:
        img_loss0 = img2mse(extras['rgb0'], rgbs_target)
        loss = loss + img_loss0
        psnr0 = mse2psnr2(img_loss0.item())
        self.log('train/PSNR0', psnr0.item(), prog_bar=True)

    ################## loss #####################
    if self.args.with_rgb_loss:
        img_loss = img2mse(rgbs, rgbs_target)
        loss += img_loss
        psnr = mse2psnr2(img_loss.item())

        with torch.no_grad():
            self.log('train/loss', loss, prog_bar=True)
            self.log('train/img_mse_loss', img_loss.item(), prog_bar=False)
            self.log('train/PSNR', psnr.item(), prog_bar=True)

    # if self.global_step == 3999 or self.global_step == 9999:
    #     self.save_ckpt(f'{self.global_step}')
    return {'loss': loss}
def validation_step(self, batch, batch_nb):
    self.MVSNet.train()
    rays, img = self.decode_batch(batch)
    rays = rays.squeeze()      # (H*W, 3)
    img = img.squeeze().cpu()  # (H, W, 3)

    N_rays_all = rays.shape[0]

    ################## rendering #####################
    keys = ['val_psnr_all']
    log = init_log({}, keys)
    with torch.no_grad():

        rgbs, depth_preds = [], []
        for chunk_idx in range(N_rays_all // args.chunk +
                               int(N_rays_all % args.chunk > 0)):
            xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
                rays[chunk_idx * args.chunk:(chunk_idx + 1) * args.chunk],
                lindisp=args.use_disp,
                N_samples=args.N_samples,
                N_importance=args.N_importance,
                density_volume=self.density_volume,
                bbox_3D=self.bbox_3d)

            # Convert world coordinates to normalized [0, 1] coordinates inside the bounding box.
            xyz_NDC = (xyz_coarse_sampled - self.bbox_3d[0].view(1, 1, 3)) / (
                self.bbox_3d[1] - self.bbox_3d[0]).view(1, 1, 3)

            # Rendering
            rgb, disp, acc, depth_pred, alpha, extras = rendering(
                args, self.pose_source_ref, xyz_coarse_sampled, xyz_NDC,
                z_vals, rays_o, rays_d, self.volume,
                **self.render_kwargs_train)
            rgbs.append(rgb.cpu())
            depth_preds.append(depth_pred.cpu())

        H, W = img.shape[:2]
        rgbs = torch.clamp(torch.cat(rgbs).reshape(H, W, 3), 0, 1)
        depth_r = torch.cat(depth_preds).reshape(H, W)
        img_err_abs = (rgbs - img).abs()
        log['val_psnr_all'] = mse2psnr(torch.mean(img_err_abs ** 2))

        depth_r, _ = visualize_depth(depth_r, self.near_far_source)
        self.logger.experiment.add_images('val/depth_gt_pred', depth_r[None],
                                          self.global_step)

        img_vis = torch.stack(
            (img, rgbs, img_err_abs.cpu() * 5)).permute(0, 3, 1, 2)
        self.logger.experiment.add_images('val/rgb_pred_err', img_vis,
                                          self.global_step)

        os.makedirs(
            f'runs_fine_tuning/{self.args.expname}/{self.args.expname}/',
            exist_ok=True)
        img_vis = torch.cat((img, rgbs, img_err_abs * 10,
                             depth_r.permute(1, 2, 0)), dim=1).numpy()
        imageio.imwrite(
            f'runs_fine_tuning/{self.args.expname}/{self.args.expname}'
            f'/{self.args.expname}_{self.global_step:08d}_{self.idx:02d}.png',
            (img_vis * 255).astype('uint8'))
        self.idx += 1

    return log
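# The loops above render rays in fixed-size chunks to bound GPU memory; the
# chunk count is a ceil division written inline. A hypothetical helper (not
# part of this codebase) making that arithmetic explicit:
def num_chunks(n_rays, chunk):
    """Number of chunks of size `chunk` needed to cover `n_rays` rays."""
    return n_rays // chunk + int(n_rays % chunk > 0)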
def fuse_local_volumes(self):
    feat_dim = 8 + 12
    volume_dim = self.volume_dim

    # Accumulators over the canonical (fused) volume: density, blending
    # weights, and appearance features, laid out as (1, C, D, H, W).
    canonical_sigma = torch.zeros(
        (1, 1, volume_dim[2], volume_dim[1], volume_dim[0])).to(device)
    canonical_weights = torch.zeros(
        (1, 1, volume_dim[2], volume_dim[1], volume_dim[0])).to(device)
    canonical_volume = torch.zeros(
        (1, feat_dim, volume_dim[2], volume_dim[1], volume_dim[0])).to(device)

    pairs = np.array(self.train_dataset.pair_idx[0])
    c2w_render = self.train_dataset.load_poses_all()[pairs]

    # Render at 1/4 resolution to keep fusion tractable.
    W, H = self.train_dataset.img_wh
    H, W = H // 4, W // 4
    img_directions = get_ray_directions(
        H, W, torch.tensor(self.train_dataset.focal) / 4.0).to(device)

    with torch.no_grad():
        for i, c2w in enumerate(tqdm(c2w_render)):
            torch.cuda.empty_cache()

            # Select the three nearest training views to the current pose.
            positions = c2w_render[:, :3, 3]
            dis = np.sum(np.abs(positions - c2w[:3, 3:].T), axis=-1)
            pair_idx = pairs[np.argsort(dis)[:3]]

            imgs_source, proj_mats, near_far_source, pose_source = \
                self.train_dataset.read_source_views(pair_idx=pair_idx,
                                                     device=device)
            volume_feature, _, _ = self.MVSNet(imgs_source, proj_mats,
                                               near_far_source, pad=args.pad)
            imgs_source = self.unpreprocess(imgs_source)

            if 0 == i:
                self.pose_source_ref = pose_source

            rays_o, rays_d = get_rays(
                img_directions,
                torch.from_numpy(c2w).float().to(device))  # both (H*W, 3)
            rays = torch.cat([
                rays_o, rays_d,
                near_far_source[0] * torch.ones_like(rays_o[:, :1]),
                near_far_source[1] * torch.ones_like(rays_o[:, :1])
            ], 1).to(device)  # (H*W, 8)

            N_rays_all = rays.shape[0]
            for chunk_idx in range(N_rays_all // args.chunk +
                                   int(N_rays_all % args.chunk > 0)):
                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
                    rays[chunk_idx * args.chunk:(chunk_idx + 1) * args.chunk],
                    N_samples=128)

                # Convert world coordinates to NDC coordinates of the local
                # reference view (intrinsics rescaled for 1/4 resolution).
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref = pose_source['w2cs'][0]
                intrinsic_ref = pose_source['intrinsics'][0].clone()
                intrinsic_ref[:2] *= 0.25
                xyz_ndc = get_ndc_coordinate(w2c_ref, intrinsic_ref,
                                             xyz_coarse_sampled, inv_scale,
                                             near=near_far_source[0],
                                             far=near_far_source[1],
                                             pad=args.pad * 0.25)

                # Render the local volume along the rays.
                rgb, ray_feat, ray_weight, depth_pred, ray_sigma, _ = rendering(
                    args, pose_source, xyz_coarse_sampled, xyz_ndc, z_vals,
                    rays_o, rays_d, volume_feature, imgs_source,
                    **self.render_kwargs_train)

                # Splat per-sample features/densities into the canonical
                # volume, weighted by the rendering weights.
                ray_ndc = (xyz_coarse_sampled - self.bbox_3d[0].view(
                    1, 1, 3)) / (self.bbox_3d[1] - self.bbox_3d[0]).view(
                        1, 1, 3)
                update_volume(canonical_volume, canonical_sigma,
                              canonical_weights, ray_feat, ray_ndc,
                              ray_sigma, ray_weight)

    # Normalize the accumulated sums by the accumulated weights
    # (per-voxel weighted average; eps guards empty voxels).
    canonical_weights = 1.0 / (canonical_weights + 1e-6)
    canonical_volume = canonical_volume * canonical_weights
    canonical_sigma = canonical_sigma * canonical_weights

    self.density_volume = canonical_sigma
    self.volume = RefVolume(canonical_volume).to(device)
    del canonical_volume, canonical_weights
    torch.cuda.empty_cache()
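# fuse_local_volumes accumulates weight-multiplied features/densities from many
# local volumes, then divides by the accumulated weights, i.e. a per-voxel
# weighted average: feat = sum_i(w_i * f_i) / (sum_i(w_i) + eps). A minimal
# sketch of that normalization step (the function and tensor names here are
# hypothetical, not part of this codebase):
def normalize_fused(feat_sum, sigma_sum, weight_sum, eps=1e-6):
    """Turn weighted sums into weighted averages, guarding empty voxels."""
    inv_w = 1.0 / (weight_sum + eps)
    return feat_sum * inv_w, sigma_sum * inv_w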