    def _render_reprojections(self, batch):
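        """Render image-based-rendering (IBR) reprojections for a batch.

        Encodes the input views into a latent volume, decodes depth and mask
        predictions for the stacked input and output cameras, and reprojects
        the input images into the output views using the predicted depths.
        """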
        recon_camera = Camera.vcat(
            (batch['in_gt']['camera'], batch['out_gt']['camera']),
            batch_size=self.batch_size)

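        # Optionally perturb the input depth with sampled noise before
        # encoding (presumably an augmentation for robustness to imperfect
        # depth), then clamp back to the normalized [-1, 1] range.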
        if self._recon_params['generator_input_depth']:
            depth_noise = self._depth_noise_dist.sample(
                batch['in']['depth'].size()).to(self.device)
            depth_real_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)
        else:
            depth_real_in = None

        # Create IBR images.
        with torch.set_grad_enabled(self.train_recon):
            z_obj, z_extra = self._sculptor.encode(
                self._fuser,
                camera=batch['in']['camera'],
                color=batch['in']['image'],
                depth=depth_real_in,
                mask=batch['in']['mask'],
                data_parallel=self.data_parallel)

            fake, _, _ = self._photographer.decode(
                z_obj, recon_camera, data_parallel=self.data_parallel)

            # Predictions are BVCHW; split them back into input and output
            # views along the view axis (dim=1).
            sections = (self.num_input_views, self.num_output_views)
            depth_fake_in, depth_fake_out = torch.split(fake['depth'],
                                                        sections,
                                                        dim=1)
            mask_fake_in, mask_fake_out = torch.split(fake['mask'],
                                                      sections,
                                                      dim=1)

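            # Reproject the input-view images into the output views using the
            # predicted depths; cam_dists_r/cam_dists_t appear to measure
            # relative camera rotation/translation between view pairs.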
            (image_reproj, depth_reproj, cam_dists_r,
             cam_dists_t) = ibr.reproject_views_batch(
                 image_in=batch['in']['image'],
                 depth_in=depth_fake_in,
                 depth_out=depth_fake_out,
                 camera_in=batch['in']['camera'],
                 camera_out=batch['out_gt']['camera'])
            # Zero out reprojections outside the predicted output masks; the
            # depth is shifted to [0, 2] before masking so that masked-out
            # pixels land back at the -1 background value.
            image_reproj = image_reproj * mask_fake_out.unsqueeze(2)
            depth_reproj = ((depth_reproj + 1.0) * mask_fake_out.unsqueeze(2)
                            - 1.0)

        return (
            image_reproj,
            depth_reproj,
            mask_fake_out.contiguous(),
            depth_fake_out.contiguous(),
            cam_dists_r,
            cam_dists_t,
        )

    def run_iteration(self, batch, train, is_step):
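        """Run a single training or evaluation iteration.

        Preprocesses the batch, encodes the input views, decodes
        reconstructions, updates the discriminator (if one is configured),
        then updates the generator with a weighted sum of GAN and
        reconstruction losses, with periodic plotting.
        """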
        self.mark_time()
        # Update depth criterion k if applicable.
        if 'hard_' in self.g_depth_recon_loss_type:
            self._g_depth_recon_criterion.k = int(
                self._g_depth_recon_k_scheduler.get(self.epoch))

        batch = process_batch(batch, self.cube_size, self.camera_dist,
                              self._sculptor.in_size, self.device,
                              self.random_orientation)

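        # Assemble reconstruction targets: when reconstructing the input, the
        # generator is supervised on the input views stacked with the output
        # views; otherwise only on the output views.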
        if self.reconstruct_input:
            recon_camera = Camera.vcat(
                (batch['in_gt']['camera'], batch['out_gt']['camera']),
                batch_size=self.batch_size)
            recon_mask = torch.cat(
                (batch['in_gt']['mask'], batch['out_gt']['mask']), dim=1)
            recon_image = torch.cat(
                (batch['in_gt']['image'], batch['out_gt']['image']), dim=1)
            recon_depth = torch.cat(
                (batch['in_gt']['depth'], batch['out_gt']['depth']), dim=1)
        else:
            recon_camera = batch['out_gt']['camera']
            recon_mask = batch['out_gt']['mask']
            recon_image = batch['out_gt']['image']
            recon_depth = batch['out_gt']['depth']

        if (not self.color_random_background) or self.crop_random_background:
            batch['in']['image'] = batch['in']['image'] * batch['in']['mask']

        if (not self.depth_random_background) or self.crop_random_background:
            batch['in']['depth'] = mask_normalized_depth(
                batch['in']['depth'], batch['in']['mask'])

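        # As in _render_reprojections, optionally add noise to the input
        # depth before encoding, clamped to the normalized [-1, 1] range.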
        depth_in = None
        if self.generator_input_depth:
            depth_noise = self._depth_noise_dist.sample(
                batch['in']['depth'].size()).to(self.device)
            depth_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)

        data_process_time = self.mark_time()

        with autocast():
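            # Mixed-precision region. Note that autocast is entered
            # unconditionally here, while the gradient scaling below only
            # happens when 'use_amp' is set.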
            # Evaluate generator.
            z_obj, z_extra = self._sculptor.encode(
                self._fuser,
                camera=batch['in']['camera'],
                color=batch['in']['image'],
                depth=depth_in,
                mask=batch['in']['mask'],
                data_parallel=self.data_parallel)
            (fake_image, fake_depth, fake_mask, fake_mask_logits,
             fake_vox_depth) = self._run_photographer(
                 z_obj, recon_camera, recon_mask)

            z_weights = z_extra.get('blend_weights')

            # Train discriminator.
            if self._discriminator:
                d_real, d_fake_d, d_fake_g = self._run_discriminator(
                    fake_image, fake_depth, fake_mask, recon_image,
                    recon_depth, recon_mask)
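                # LSGAN objectives: the discriminator is pushed toward 1 on
                # real and 0 on fake samples; the generator is scored against
                # a target of 1 on a separate fake pass (d_fake_g, presumably
                # not detached, so its gradients reach the generator).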
                loss_d_real = multiscale_lsgan_loss(d_real, 1)
                loss_d_fake = multiscale_lsgan_loss(d_fake_d, 0)
                loss_d = loss_d_real + loss_d_fake
                loss_g_gan = multiscale_lsgan_loss(d_fake_g, 1)

                if train:
                    loss_d.backward()
                    if is_step:
                        self._optimizers['discriminator'].step()

                self.plotter.put_scalar('loss/discriminator/real', loss_d_real)
                self.plotter.put_scalar('loss/discriminator/fake', loss_d_fake)
                self.plotter.put_scalar('loss/discriminator/total', loss_d)
            else:
                loss_g_gan = torch.tensor(0.0, device=self.device)

            # Train generator.
            if self.predict_color:
                loss_g_color_recon = reduce_loss(
                    self._g_color_recon_criterion(fake_image, recon_image))
            else:
                loss_g_color_recon = torch.tensor(0.0, device=self.device)

            if self.predict_depth or self.use_occlusion_depth:
                loss_g_depth_recon = reduce_loss(
                    self._g_depth_recon_criterion(fake_depth, recon_depth))
            else:
                loss_g_depth_recon = torch.tensor(0.0, device=self.device)

            if self.predict_mask:
                if self.g_mask_recon_loss_type == 'binary_cross_entropy':
                    y_mask = fake_mask_logits
                else:
                    y_mask = fake_mask
                loss_g_mask_recon = reduce_loss(
                    self._g_mask_recon_criterion(y_mask, recon_mask))
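                # A symmetric beta prior (alpha == beta) on the predicted
                # mask; with a shared parameter below 1 this is a common way
                # to push mask values toward binary 0/1.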
                loss_g_mask_beta = beta_prior_loss(
                    fake_mask,
                    alpha=self.g_mask_beta_loss_param,
                    beta=self.g_mask_beta_loss_param)
            else:
                loss_g_mask_recon = torch.tensor(0.0, device=self.device)
                loss_g_mask_beta = torch.tensor(0.0, device=self.device)

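            # Weighted sum of the generator losses; dividing by batch_groups
            # (presumably the number of gradient-accumulation sub-batches)
            # keeps gradient magnitudes consistent across accumulation
            # settings.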
            loss_g = (self.g_gan_loss_weight * loss_g_gan +
                      self.g_color_recon_loss_weight * loss_g_color_recon +
                      self.g_depth_recon_loss_weight * loss_g_depth_recon +
                      self.g_mask_recon_loss_weight * loss_g_mask_recon +
                      self.g_mask_beta_loss_weight * loss_g_mask_beta
                      ) / self.batch_groups

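        # Backpropagate the generator loss. Under AMP the loss is scaled to
        # avoid fp16 gradient underflow; GradScaler.step() unscales the
        # gradients before calling the optimizer, and update() adjusts the
        # scale factor.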
        if train:
            if self.kwargs.get('use_amp', False):
                self._scaler.scale(loss_g).backward()
            else:
                loss_g.backward()

            if is_step:
                if self.kwargs.get('use_amp', False):
                    self._scaler.step(self._optimizers['generator'])
                    self._scaler.update()
                else:
                    self._optimizers['generator'].step()

        with torch.no_grad():
            if self.predict_depth:
                self.plotter.put_scalar('error/depth/l1',
                                        F.l1_loss(fake_depth, recon_depth))
            if self.reconstruct_input:
                self.plotter.put_scalar(
                    'error/depth/input_l1',
                    F.l1_loss(fake_depth[:, :self.num_input_views],
                              batch['in_gt']['depth']))
                self.plotter.put_scalar(
                    'error/depth/output_l1',
                    F.l1_loss(fake_depth[:, self.num_input_views:],
                              batch['out_gt']['depth']))
            if self.predict_mask:
                self.plotter.put_scalar(
                    'error/mask/cross_entropy',
                    F.binary_cross_entropy_with_logits(fake_mask_logits,
                                                       recon_mask))
                self.plotter.put_scalar('error/mask/l1',
                                        F.l1_loss(fake_mask, recon_mask))

        compute_time = self.mark_time()

        self.plotter.put_scalar('loss/generator/gan', loss_g_gan)
        self.plotter.put_scalar('loss/generator/recon/color',
                                loss_g_color_recon)
        self.plotter.put_scalar('loss/generator/recon/depth',
                                loss_g_depth_recon)
        self.plotter.put_scalar('loss/generator/recon/mask', loss_g_mask_recon)
        self.plotter.put_scalar('loss/generator/recon/mask_beta',
                                loss_g_mask_beta)
        self.plotter.put_scalar('loss/generator/total', loss_g)

        self.plotter.put_scalar('params/input_noise_weight',
                                self.input_noise_weight)
        if hasattr(self._g_depth_recon_criterion, 'k'):
            self.plotter.put_scalar('params/depth_loss_k',
                                    self._g_depth_recon_criterion.k)

        self.plotter.put_scalar('time/data_process', data_process_time)
        self.plotter.put_scalar('time/compute', compute_time)
        plot_scalar_time = self.mark_time()
        self.plotter.put_scalar('time/plot/scalars', plot_scalar_time)

        if self.plotter.is_it_time_yet('histogram'):
            if self.predict_color:
                self.plotter.put_histogram('image_fake', fake_image)
                self.plotter.put_histogram('image_real', recon_image)
            if self.predict_mask:
                self.plotter.put_histogram('mask_fake', fake_mask)
            self.plotter.put_histogram('z_obj', z_obj)
            if z_weights is not None:
                self.plotter.put_histogram('z_weights', z_weights)
        plot_histogram_time = self.mark_time()
        self.plotter.put_scalar('time/plot/histogram', plot_histogram_time)

        if self.plotter.is_it_time_yet('show'):
            self.plotter.put_image(
                'inputs',
                viz.make_grid([
                    gan_denormalize(batch['in']['image']),
                    viz.colorize_depth(batch['in']['depth'])
                    if self.generator_input_depth else None,
                    viz.colorize_tensor(batch['in']['mask'])
                    if self.generator_input_mask else None,
                ],
                              row_size=4,
                              stride=2,
                              output_size=64))
            with torch.no_grad():
                self.plotter.put_image(
                    'reconstruction',
                    viz.make_grid([
                        gan_denormalize(recon_image),
                        gan_denormalize(fake_image) if
                        (fake_image is not None) else None,
                        viz.colorize_depth(recon_depth),
                        viz.colorize_depth(fake_depth) if
                        (fake_depth is not None) else None,
                        viz.colorize_tensor(
                            (recon_depth.cpu() - fake_depth.cpu()).abs()) if
                        (fake_depth is not None) else None,
                        viz.colorize_tensor(recon_mask),
                        viz.colorize_tensor(fake_mask) if
                        (fake_mask is not None) else None,
                        viz.colorize_tensor(
                            (recon_mask.cpu() - fake_mask.cpu()).abs()) if
                        (fake_mask is not None) else None,
                    ],
                                  stride=8))
        plot_images_time = self.mark_time()
        self.plotter.put_scalar('time/plot/images', plot_images_time)