Example #1
    def training_step(self, batch, batch_nb):

        rays, rgbs_target = self.decode_batch(batch)

        # periodically refresh the density volume used for importance sampling
        if args.N_importance and self.global_step % 500 == 0:
            self.update_density_volume()

        xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
            rays,
            N_samples=args.N_samples,
            N_importance=args.N_importance,
            lindisp=args.use_disp,
            perturb=args.perturb,
            density_volume=self.density_volume,
            bbox_3D=self.bbox_3d)

        # Normalize world coordinates into the [0, 1] bounding-box volume
        # (used as NDC-like coordinates for sampling the feature volume)
        xyz_NDC = (xyz_coarse_sampled - self.bbox_3d[0].view(1, 1, 3)) / (
            self.bbox_3d[1] - self.bbox_3d[0]).view(1, 1, 3)

        # rendering
        rgbs, disp, acc, depth_pred, alpha, extras = rendering(
            args, self.pose_source_ref, xyz_coarse_sampled, xyz_NDC, z_vals,
            rays_o, rays_d, self.volume, **self.render_kwargs_train)

        log, loss = {}, 0
        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], rgbs_target)
            loss = loss + img_loss0
            psnr0 = mse2psnr2(img_loss0.item())
            self.log('train/PSNR0', psnr0.item(), prog_bar=True)

        ##################  RGB loss #####################
        if self.args.with_rgb_loss:
            img_loss = img2mse(rgbs, rgbs_target)
            loss += img_loss
            psnr = mse2psnr2(img_loss.item())

            with torch.no_grad():
                self.log('train/loss', loss, prog_bar=True)
                self.log('train/img_mse_loss', img_loss.item(), prog_bar=False)
                self.log('train/PSNR', psnr.item(), prog_bar=True)

        # if self.global_step == 3999 or self.global_step == 9999:
        #     self.save_ckpt(f'{self.global_step}')

        return {'loss': loss}
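
The img2mse and mse2psnr2 helpers called above are imported elsewhere in the repo and not shown here. A minimal sketch of what they typically look like in NeRF codebases, assuming the usual definitions; note that mse2psnr2 must accept the Python float produced by .item() and return something that itself supports .item(), which a NumPy scalar does:

import numpy as np
import torch

# Hypothetical stand-ins for the loss helpers used in training_step.
def img2mse(pred, target):
    # mean-squared error between rendered and target colors
    return torch.mean((pred - target) ** 2)

def mse2psnr2(mse):
    # PSNR from a scalar MSE; returns a NumPy scalar, so .item() works on it
    return -10.0 * np.log(mse) / np.log(10.0)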
Example #2
    def validation_step(self, batch, batch_nb):

        self.MVSNet.train()
        rays, img = self.decode_batch(batch)
        img = img.cpu()  # (H, W, 3)
        # mask = batch['mask'][0]

        N_rays_all = rays.shape[0]

        ##################  rendering #####################
        keys = ['val_psnr_all']
        log = init_log({}, keys)
        with torch.no_grad():

            rgbs, depth_preds = [], []
            # iterate over ceil(N_rays_all / chunk) chunks of rays
            for chunk_idx in range(N_rays_all // args.chunk +
                                   int(N_rays_all % args.chunk > 0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
                    rays[chunk_idx * args.chunk:(chunk_idx + 1) * args.chunk],
                    N_samples=args.N_samples,
                    lindisp=args.use_disp)

                # Convert world coordinates to NDC coordinates
                H, W = img.shape[:2]
                inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                w2c_ref = self.pose_source['w2cs'][0]
                intrinsic_ref = self.pose_source['intrinsics'][0].clone()
                # rescale intrinsics from the training to the test image scale
                intrinsic_ref[:2] *= args.imgScale_test / args.imgScale_train
                xyz_NDC = get_ndc_coordinate(w2c_ref,
                                             intrinsic_ref,
                                             xyz_coarse_sampled,
                                             inv_scale,
                                             near=self.near_far_source[0],
                                             far=self.near_far_source[1],
                                             pad=args.pad * args.imgScale_test,
                                             lindisp=args.use_disp)

                # importance sampling
                if self.density_volume is not None and args.N_importance > 0:
                    xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher_fine(
                        rays[chunk_idx * args.chunk:(chunk_idx + 1) *
                             args.chunk],
                        self.density_volume,
                        z_vals,
                        xyz_NDC,
                        N_importance=args.N_importance)
                    xyz_NDC = get_ndc_coordinate(w2c_ref,
                                                 intrinsic_ref,
                                                 xyz_coarse_sampled,
                                                 inv_scale,
                                                 near=self.near_far_source[0],
                                                 far=self.near_far_source[1],
                                                 pad=args.pad,
                                                 lindisp=args.use_disp)

                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(
                    args, self.pose_source, xyz_coarse_sampled, xyz_NDC,
                    z_vals, rays_o, rays_d, self.volume, self.imgs,
                    **self.render_kwargs_train)

                rgbs.append(rgb.cpu())
                depth_preds.append(depth_pred.cpu())

            rgbs = torch.clamp(torch.cat(rgbs).reshape(H, W, 3), 0, 1)
            depth_r = torch.cat(depth_preds).reshape(H, W)
            img_err_abs = (rgbs - img).abs()

            log['val_psnr_all'] = mse2psnr(torch.mean(img_err_abs**2))
            depth_r, _ = visualize_depth(depth_r, self.near_far_source)
            self.logger.experiment.add_images('val/depth_gt_pred',
                                              depth_r[None], self.global_step)

            img_vis = torch.stack(
                (img, rgbs, img_err_abs.cpu() * 5)).permute(0, 3, 1, 2)
            self.logger.experiment.add_images('val/rgb_pred_err', img_vis,
                                              self.global_step)
            os.makedirs(
                f'runs_fine_tuning/{self.args.expname}/{self.args.expname}/',
                exist_ok=True)

            img_vis = torch.cat(
                (img, rgbs, img_err_abs * 10, depth_r.permute(1, 2, 0)),
                dim=1).numpy()
            imageio.imwrite(
                f'runs_fine_tuning/{self.args.expname}/{self.args.expname}/{self.global_step:08d}_{self.idx:02d}.png',
                (img_vis * 255).astype('uint8'))
            self.idx += 1

        return log
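
The loop bound above, N_rays_all // args.chunk + int(N_rays_all % args.chunk > 0), is just ceiling division: every ray is processed in fixed-size chunks so peak GPU memory stays bounded. The same pattern in isolation, with a hypothetical render_chunk callable standing in for the ray-marching and rendering calls:

import torch

def render_in_chunks(rays, chunk, render_chunk):
    # rays: (N, 8) tensor; render_chunk maps a slice of rays to per-ray outputs.
    outputs = []
    n_chunks = rays.shape[0] // chunk + int(rays.shape[0] % chunk > 0)  # ceil division
    with torch.no_grad():
        for i in range(n_chunks):
            outputs.append(render_chunk(rays[i * chunk:(i + 1) * chunk]))
    return torch.cat(outputs)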
Example #3
    def training_step(self, batch, batch_nb):

        rays, rgbs_target = self.decode_batch(batch)

        # periodically refresh the density volume used for importance sampling
        if args.use_density_volume and self.global_step % 200 == 0:
            self.update_density_volume()

        xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
            rays,
            N_samples=args.N_samples,
            lindisp=args.use_disp,
            perturb=args.perturb)

        # Convert world coordinates to NDC coordinates
        H, W = self.imgs.shape[-2:]
        inv_scale = torch.tensor([W - 1, H - 1]).to(device)
        w2c_ref = self.pose_source['w2cs'][0]
        intrinsic_ref = self.pose_source['intrinsics'][0]
        xyz_NDC = get_ndc_coordinate(w2c_ref,
                                     intrinsic_ref,
                                     xyz_coarse_sampled,
                                     inv_scale,
                                     near=self.near_far_source[0],
                                     far=self.near_far_source[1],
                                     pad=args.pad,
                                     lindisp=args.use_disp)

        # importance sampling
        if self.density_volume is not None and args.N_importance > 0:
            xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher_fine(
                rays,
                self.density_volume,
                z_vals,
                xyz_NDC,
                N_importance=args.N_importance)
            xyz_NDC = get_ndc_coordinate(w2c_ref,
                                         intrinsic_ref,
                                         xyz_coarse_sampled,
                                         inv_scale,
                                         near=self.near_far_source[0],
                                         far=self.near_far_source[1],
                                         pad=args.pad,
                                         lindisp=args.use_disp)

        # rendering
        rgbs, disp, acc, depth_pred, alpha, extras = rendering(
            args, self.pose_source, xyz_coarse_sampled, xyz_NDC, z_vals,
            rays_o, rays_d, self.volume, self.imgs, **self.render_kwargs_train)

        log, loss = {}, 0
        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], rgbs_target)
            loss = loss + img_loss0
            psnr0 = mse2psnr2(img_loss0.item())
            self.log('train/PSNR0', psnr0.item(), prog_bar=True)

        ##################  RGB loss #####################
        if self.args.with_rgb_loss:
            img_loss = img2mse(rgbs, rgbs_target)
            loss += img_loss
            psnr = mse2psnr2(img_loss.item())

            with torch.no_grad():
                self.log('train/loss', loss, prog_bar=True)
                self.log('train/img_mse_loss', img_loss.item(), prog_bar=False)
                self.log('train/PSNR', psnr.item(), prog_bar=True)

        # if self.global_step == 3999 or self.global_step == 9999:
        #     self.save_ckpt(f'{self.global_step}')

        return {'loss': loss}
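
get_ndc_coordinate is imported from the repo and not shown in these examples. As a rough sketch of the idea only: project each sample point into the reference view, normalize pixel coordinates by inv_scale = [W-1, H-1] and depth by the near/far range. The conventions below are assumptions and need not match the real function, which also handles the pad and lindisp arguments:

import torch

def ndc_coordinate_sketch(w2c_ref, intrinsic_ref, xyz, inv_scale, near, far):
    # xyz: (N_rays, N_samples, 3) world-space sample points; w2c_ref: (4, 4).
    xyz_h = torch.cat([xyz, torch.ones_like(xyz[..., :1])], dim=-1)  # homogeneous
    xyz_cam = (xyz_h @ w2c_ref.T)[..., :3]  # reference camera frame
    # project with the reference intrinsics and normalize pixels to [0, 1]
    uvz = xyz_cam @ intrinsic_ref.T
    uv = uvz[..., :2] / uvz[..., 2:3] / inv_scale  # inv_scale = [W-1, H-1]
    # normalize depth linearly between the near and far planes
    z = (xyz_cam[..., 2:3] - near) / (far - near)
    return torch.cat([uv, z], dim=-1)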
Example #4
    def validation_step(self, batch, batch_nb):

        self.MVSNet.train()
        rays, img = self.decode_batch(batch)
        rays = rays.squeeze()  # (H*W, 3)
        img = img.squeeze().cpu()  # (H, W, 3)

        N_rays_all = rays.shape[0]

        ##################  rendering #####################
        keys = ['val_psnr_all']
        log = init_log({}, keys)
        with torch.no_grad():

            rgbs, depth_preds = [], []
            for chunk_idx in range(N_rays_all // args.chunk +
                                   int(N_rays_all % args.chunk > 0)):

                xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
                    rays[chunk_idx * args.chunk:(chunk_idx + 1) * args.chunk],
                    lindisp=args.use_disp,
                    N_samples=args.N_samples,
                    N_importance=args.N_importance,
                    density_volume=self.density_volume,
                    bbox_3D=self.bbox_3d)

                # Normalize world coordinates into the [0, 1] bounding-box volume
                xyz_NDC = (xyz_coarse_sampled - self.bbox_3d[0].view(
                    1, 1, 3)) / (self.bbox_3d[1] - self.bbox_3d[0]).view(
                        1, 1, 3)

                # rendering
                rgb, disp, acc, depth_pred, alpha, extras = rendering(
                    args, self.pose_source_ref, xyz_coarse_sampled, xyz_NDC,
                    z_vals, rays_o, rays_d, self.volume,
                    **self.render_kwargs_train)

                rgbs.append(rgb.cpu())
                depth_preds.append(depth_pred.cpu())

            H, W = img.shape[:2]
            rgbs = torch.clamp(torch.cat(rgbs).reshape(H, W, 3), 0, 1)
            depth_r = torch.cat(depth_preds).reshape(H, W)
            img_err_abs = (rgbs - img).abs()

            log['val_psnr_all'] = mse2psnr(torch.mean(img_err_abs**2))
            depth_r, _ = visualize_depth(depth_r, self.near_far_source)
            self.logger.experiment.add_images('val/depth_gt_pred',
                                              depth_r[None], self.global_step)

            img_vis = torch.stack(
                (img, rgbs, img_err_abs.cpu() * 5)).permute(0, 3, 1, 2)
            self.logger.experiment.add_images('val/rgb_pred_err', img_vis,
                                              self.global_step)
            os.makedirs(
                f'runs_fine_tuning/{self.args.expname}/{self.args.expname}/',
                exist_ok=True)

            img_vis = torch.cat(
                (img, rgbs, img_err_abs * 10, depth_r.permute(1, 2, 0)),
                dim=1).numpy()
            imageio.imwrite(
                f'runs_fine_tuning/{self.args.expname}/{self.args.expname}'
                f'/{self.args.expname}_{self.global_step:08d}_{self.idx:02d}.png',
                (img_vis * 255).astype('uint8'))
            self.idx += 1

        return log
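
Unlike Examples #2 and #3, Examples #1 and #4 build xyz_NDC by normalizing world points into the scene bounding box rather than by projecting through a reference camera. The inline expression is equivalent to a small helper like this (the helper name is ours):

import torch

def normalize_to_bbox(xyz, bbox_min, bbox_max):
    # Map (N_rays, N_samples, 3) world points into [0, 1]^3 relative to an
    # axis-aligned bounding box; points outside the box land outside [0, 1].
    return (xyz - bbox_min.view(1, 1, 3)) / (bbox_max - bbox_min).view(1, 1, 3)

# usage matching the snippets above:
# xyz_NDC = normalize_to_bbox(xyz_coarse_sampled, self.bbox_3d[0], self.bbox_3d[1])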
Example #5
    def fuse_local_volumes(self):

        feat_dim = 8 + 12  # channels stored per voxel in the fused feature volume
        volume_dim = self.volume_dim

        canonical_sigma = torch.zeros(
            (1, 1, volume_dim[2], volume_dim[1], volume_dim[0])).to(device)
        canonical_weights = torch.zeros(
            (1, 1, volume_dim[2], volume_dim[1], volume_dim[0])).to(device)
        canonical_volume = torch.zeros(
            (1, feat_dim, volume_dim[2], volume_dim[1],
             volume_dim[0])).to(device)

        pairs = np.array(self.train_dataset.pair_idx[0])
        c2w_render = self.train_dataset.load_poses_all()[pairs]

        W, H = self.train_dataset.img_wh
        H, W = H // 4, W // 4  # fuse at quarter resolution (focal scaled to match below)
        img_directions = get_ray_directions(
            H, W,
            torch.tensor(self.train_dataset.focal) / 4.0).to(device)

        with torch.no_grad():
            for i, c2w in enumerate(tqdm(c2w_render)):
                torch.cuda.empty_cache()

                # find the three nearest training views to the current pose
                positions = c2w_render[:, :3, 3]
                dis = np.sum(np.abs(positions - c2w[:3, 3:].T), axis=-1)
                pair_idx = pairs[np.argsort(dis)[:3]]

                imgs_source, proj_mats, near_far_source, pose_source = self.train_dataset.read_source_views(
                    pair_idx=pair_idx, device=device)
                volume_feature, _, _ = self.MVSNet(imgs_source,
                                                   proj_mats,
                                                   near_far_source,
                                                   pad=args.pad)
                imgs_source = self.unpreprocess(imgs_source)
                if i == 0:
                    self.pose_source_ref = pose_source

                rays_o, rays_d = get_rays(
                    img_directions,
                    torch.from_numpy(c2w).float().to(device))  # both (h*w, 3)
                rays = torch.cat([
                    rays_o, rays_d,
                    near_far_source[0] * torch.ones_like(rays_o[:, :1]),
                    near_far_source[1] * torch.ones_like(rays_o[:, :1])
                ], 1).to(device)  # (H*W, 8): origin, direction, near, far

                N_rays_all = rays.shape[0]
                rgb_rays, depth_rays_preds = [], []
                for chunk_idx in range(N_rays_all // args.chunk +
                                       int(N_rays_all % args.chunk > 0)):

                    xyz_coarse_sampled, rays_o, rays_d, z_vals = ray_marcher(
                        rays[chunk_idx * args.chunk:(chunk_idx + 1) *
                             args.chunk],
                        N_samples=128)

                    # Convert world coordinates to NDC coordinates
                    inv_scale = torch.tensor([W - 1, H - 1]).to(device)
                    w2c_ref = pose_source['w2cs'][0]
                    intrinsic_ref = pose_source['intrinsics'][0].clone()
                    intrinsic_ref[:2] *= 0.25  # match the quarter-resolution rays

                    xyz_ndc = get_ndc_coordinate(w2c_ref,
                                                 intrinsic_ref,
                                                 xyz_coarse_sampled,
                                                 inv_scale,
                                                 near=near_far_source[0],
                                                 far=near_far_source[1],
                                                 pad=args.pad * 0.25)

                    # rendering
                    rgb, ray_feat, ray_weight, depth_pred, ray_sigma, _ = rendering(
                        args, pose_source, xyz_coarse_sampled, xyz_ndc, z_vals,
                        rays_o, rays_d, volume_feature, imgs_source,
                        **self.render_kwargs_train)

                    ray_ndc = (xyz_coarse_sampled - self.bbox_3d[0].view(
                        1, 1, 3)) / (self.bbox_3d[1] - self.bbox_3d[0]).view(
                            1, 1, 3)
                    update_volume(canonical_volume, canonical_sigma,
                                  canonical_weights, ray_feat, ray_ndc,
                                  ray_sigma, ray_weight)

                #     rgb, depth_pred = torch.clamp(rgb.cpu(), 0, 1.0).numpy(), depth_pred.cpu().numpy()
                #     rgb_rays.append(rgb)
                #     depth_rays_preds.append(depth_pred)
                #
                # depth_rays_preds = np.concatenate(depth_rays_preds).reshape(H, W)
                # depth_rays_preds, _ = visualize_depth_numpy(depth_rays_preds, near_far_source)
                #
                # rgb_rays = np.concatenate(rgb_rays).reshape(H, W, 3)
                # img_vis = np.concatenate((rgb_rays * 255, depth_rays_preds), axis=1)
                # imageio.imwrite(f'/mnt/new_disk2/anpei/code/MVS-NeRF/results/test4/{i:03d}.png', img_vis.astype('uint8'))

        # normalize the accumulated features and densities by the summed weights
        # (1e-6 guards against division by zero in empty voxels)
        canonical_weights = 1.0 / (canonical_weights + 1e-6)
        canonical_volume = canonical_volume * canonical_weights
        canonical_sigma = canonical_sigma * canonical_weights

        # mask = canonical_weights > 0
        # weights = canonical_weights.clone()
        # weights[mask] = 1.0 / weights[mask]
        # canonical_volume = canonical_volume * weights

        self.density_volume = canonical_sigma
        self.volume = RefVolume(canonical_volume).to(device)

        del canonical_volume, canonical_weights
        torch.cuda.empty_cache()
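
update_volume accumulates each chunk's features, densities, and rendering weights into the canonical grids, which are then divided by the summed weights above, i.e. a weighted average over all views. Its implementation is not shown; below is a simplified nearest-voxel sketch with a hypothetical signature mirroring the call site (the real function may splat trilinearly and differ in detail):

import torch

def update_volume_sketch(volume, sigma_vol, weight_vol,
                         ray_feat, ray_ndc, ray_sigma, ray_weight):
    # volume: (1, C, D, H, W); sigma_vol, weight_vol: (1, 1, D, H, W).
    # ray_ndc holds sample positions in [0, 1]^3, ray_weight the rendering weights.
    C, D, H, W = volume.shape[1:]
    # nearest-voxel index for every sample
    grid = torch.tensor([W - 1, H - 1, D - 1], device=ray_ndc.device)
    idx = (ray_ndc.reshape(-1, 3) * grid).round().long()
    x = idx[:, 0].clamp(0, W - 1)
    y = idx[:, 1].clamp(0, H - 1)
    z = idx[:, 2].clamp(0, D - 1)
    flat = (z * H + y) * W + x  # linear voxel index into (D, H, W)
    w = ray_weight.reshape(-1, 1)
    # index_add_ accumulates correctly even when samples share a voxel
    volume.view(C, -1).index_add_(1, flat, (ray_feat.reshape(-1, C) * w).T)
    sigma_vol.view(1, -1).index_add_(1, flat, (ray_sigma.reshape(-1, 1) * w).T)
    weight_vol.view(1, -1).index_add_(1, flat, w.T)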