Example #1
    def _estimate(self, z_obj, target_obs, **kwargs):
        if 'camera' in kwargs:
            camera = kwargs['camera']
        else:
            camera = self.initial_pose(target_obs)
            camera = pu.sample_cameras_with_estimate(
                n=self.num_samples, camera_est=camera)
        target_obs = target_obs.to(self.device)

        # Mask parameters.

        # Optimize the 'zoomed' camera.
        camera = camera.zoom(None, self.model.input_size, self.model.camera_dist).to(self.device)

        ranking = []
        stat_history, camera_history = self._optimize_camera(
            z_obj, target_obs, camera,
            iters=self.num_iters,
            ranking=ranking)

        logger.info('best camera', step=ranking[0][2], loss=ranking[0][1])
        best_cameras = Camera.cat([c for c, loss, step in ranking])

        if self.track_stats and self.return_camera_history:
            return best_cameras, stat_history, camera_history
        elif self.track_stats:
            return best_cameras, stat_history
        elif self.return_camera_history:
            return best_cameras, camera_history
        else:
            return best_cameras
Example #2
def parse_poserbpf_cameras(seq_path):
    mat_paths = sorted(seq_path.glob('*.mat'))
    inds = [int(x.name.split('.')[0]) for x in mat_paths]
    cameras = []
    for mat in mat_paths:
        cameras.append(load_poserbpf_camera(mat))
    return Camera.cat(cameras)
Example #3
def load_poserbpf_camera(mat_path, key='poses'):
    mat = loadmat(mat_path)
    intrinsic = torch.tensor(mat['intrinsic_matrix']).float()
    pose = torch.tensor(mat[key]).squeeze().float()
    quat = pose[:4]
    translation = pose[4:]
    extrinsic = three.to_extrinsic_matrix(translation, quat)
    camera = Camera(intrinsic=intrinsic, extrinsic=extrinsic)
    return camera
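A minimal usage sketch for the two loaders above, assuming a PoseRBPF result directory that stores one NNNN.mat file per frame; the path below is hypothetical.

from pathlib import Path

# Hypothetical sequence directory; the per-frame NNNN.mat layout is an assumption.
seq_path = Path('results/poserbpf/seq-0001')
cameras = parse_poserbpf_cameras(seq_path)
print(len(cameras), cameras.intrinsic.shape)

# A single frame can also be loaded directly.
camera = load_poserbpf_camera(seq_path / '0000.mat')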
Example #4
def item_to_obs(item):
    height, width = item['color'].shape[-2:]
    camera = Camera(intrinsic=item['intrinsic'],
                    extrinsic=item['extrinsic'],
                    width=width,
                    height=height)
    return latentfusion.observation.Observation(
        item['color'].unsqueeze(0),
        item['depth'].unsqueeze(0).unsqueeze(0),
        item['mask'].unsqueeze(0).unsqueeze(0).float(),
        camera)
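As a hedged illustration of the helper above, a single dataset sample can be wrapped into an Observation; the dataset object is hypothetical and is only assumed to provide the keys accessed by item_to_obs.

item = dataset[0]  # assumed keys: 'color', 'depth', 'mask', 'intrinsic', 'extrinsic'
obs = item_to_obs(item)
print(obs.color.shape, obs.camera.width, obs.camera.height)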
Example #5
def estimate_initial_pose(depth, mask, intrinsic, width, height) -> Camera:
    """Estimate the initial pose based on depth."""
    translation = torch.stack(estimate_translation(depth, mask, intrinsic), dim=-1)
    rotation = three.quaternion.identity(intrinsic.shape[0], intrinsic.device)
    extrinsic = three.to_extrinsic_matrix(translation, rotation)

    camera = Camera(intrinsic, extrinsic, height=height, width=width)

    return camera
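A small sketch of calling the initializer above; obs is the hypothetical Observation from the previous sketch.

camera_init = estimate_initial_pose(obs.depth, obs.mask, obs.camera.intrinsic,
                                    width=obs.camera.width, height=obs.camera.height)
# camera_init carries an identity rotation and a translation estimated from the depth map.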
Example #6
def _process_batch(batch, rotation, cube_size, camera_dist, input_size, device,
                   is_gt):
    # Collapse viewpoint dimension to batch dimension:
    #   (B, V, C, H, W) => (B*V, C, H, W)
    batch_size = batch['mask'].shape[0]
    extrinsic = bv2b(batch['extrinsic'].to(device))
    intrinsic = bv2b(batch['intrinsic'].to(device))
    mask = bv2b(batch['mask'].unsqueeze(2).float().to(device))
    image = bv2b(gan_normalize(batch['render'].to(device)))
    if 'depth' in batch:
        depth = bv2b(batch['depth'].unsqueeze(2).to(device))
    else:
        depth = None

    # Project image features onto canonical volume.
    camera = Camera(intrinsic,
                    extrinsic,
                    z_span=cube_size / 2.0,
                    height=image.size(2),
                    width=image.size(3)).to(device)
    if rotation is not None:
        camera.rotate(rotation.expand(camera.length, -1))
        # translation = three.uniform(3, -cube_size/16, cube_size/16).view(1, 3).expand(camera.length, -1).to(device)
        # camera.translate(translation)
    _zoom = functools.partial(camera.zoom,
                              target_size=input_size,
                              target_dist=camera_dist)

    out = dict()
    # Zoom camera to canonical distance and size.
    out['image'], out['camera'] = _zoom(image, scale_mode='bilinear')
    out['mask'] = _zoom(mask, scale_mode='nearest')[0]
    if depth is not None:
        out['depth'] = camera.normalize_depth(
            _zoom(depth, scale_mode='nearest')[0])

    if is_gt:
        out['image'] = out['image'] * out['mask']
        out['depth'] = mask_normalized_depth(out['depth'], out['mask'])

    for k in {'image', 'depth', 'mask'}:
        out[k] = b2bv(out[k], batch_size=batch_size)

    return out
Example #7
    def _estimate(self, z_obj, target_obs, **kwargs):
        if kwargs.get('cameras') is not None:
            cameras = kwargs['cameras']
            camera_init = kwargs['cameras'][0]
        else:
            camera_init = self.initial_pose(target_obs)
            cameras = pu.sample_cameras_with_estimate(
                n=self.num_gmm_components * self.num_samples,
                camera_est=camera_init,
                upright=self.init_upright,
                hemisphere=self.init_hemisphere)

        gmm = self._create_gmm(self._camera_to_params(cameras))
        target_obs = target_obs.to(self.device)
        camera_history = []

        prev_gmm = None
        ranking = []
        pbar = utils.trange(self.num_iters)
        for step in pbar:
            # Refine pose.
            _num_elites = int(self.elite_sched.get(step))
            cameras, losses = self._refine_pose(z_obj, target_obs, prev_gmm, gmm,
                                                num_elites=_num_elites,
                                                camera_init=camera_init)
            prev_gmm = gmm
            gmm = self._create_gmm(self._camera_to_params(cameras).cpu())
            delta = self._track_best_items(ranking, step, cameras, losses)
            if delta > 0:
                camera_history.append((losses, Camera.cat([c for c, e, step in ranking])))
            pbar.set_description(f"best_error={ranking[0][1]:.05f}, num_elite={_num_elites}")

        # gmm_camera = self._params_to_camera(torch.tensor(gmm.means_, dtype=torch.float32),
        #                                     camera_init=camera_init)

        logger.info('best camera', step=ranking[0][2], loss=ranking[0][1])

        cameras = Camera.cat([c for c, e, step in ranking])
        if self.return_camera_history:
            return cameras, camera_history
        else:
            return cameras
Example #8
    def from_dict(cls, d):
        height, width = d['color'].shape[-2:]
        camera = Camera(d['intrinsic'],
                        d['extrinsic'],
                        width=width,
                        height=height)
        return cls(
            d['color'],
            d['depth'].unsqueeze(-3),  # Create channel dimension.
            d['mask'].unsqueeze(-3).float(),
            camera)
Example #9
def render_observation(renderer, scene):
    color, depth, mask = renderer.render(scene)
    camera = Camera(scene.intrinsic,
                    scene.extrinsic,
                    width=renderer.width,
                    height=renderer.height)

    return Observation(color.permute(2, 0, 1).unsqueeze(0),
                       depth.unsqueeze(0).unsqueeze(0),
                       mask.unsqueeze(0).unsqueeze(0),
                       camera,
                       object_scale=scene.obj.scale)
Example #10
    def zoom(self, target_dist, target_size, camera: Camera = None):
        if camera is None:
            camera = self.camera

        color, new_camera = camera.zoom(self.color,
                                        target_size,
                                        target_dist,
                                        scale_mode='bilinear')
        depth, _ = camera.zoom(self.depth,
                               target_size,
                               target_dist,
                               scale_mode='nearest')
        mask, _ = camera.zoom(self.mask,
                              target_size,
                              target_dist,
                              scale_mode='nearest')

        kwargs = copy.deepcopy(self.meta)
        kwargs['is_zoomed'] = True

        return Observation(color, depth, mask, new_camera, **kwargs)
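A hedged usage sketch of the zoom method above; the target size and distance are placeholder values, not values taken from the original configuration.

# Zoom an observation to a fixed crop size and canonical camera distance.
obs_zoomed = obs.zoom(target_dist=0.6, target_size=128)  # placeholder values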
Example #11
    def _render_reprojections(self, batch):
        recon_camera = Camera.vcat(
            (batch['in_gt']['camera'], batch['out_gt']['camera']),
            batch_size=self.batch_size)

        if self._recon_params['generator_input_depth']:
            depth_noise = self._depth_noise_dist.sample(
                batch['in']['depth'].size()).to(self.device)
            depth_real_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)
        else:
            depth_real_in = None

        # Create IBR images.
        with torch.set_grad_enabled(self.train_recon):
            z_obj, z_extra = self._sculptor.encode(
                self._fuser,
                camera=batch['in']['camera'],
                color=batch['in']['image'],
                depth=depth_real_in,
                mask=batch['in']['mask'],
                data_parallel=self.data_parallel)

            fake, _, _ = self._photographer.decode(
                z_obj, recon_camera, data_parallel=self.data_parallel)

            # Reshape things to BVCHW.
            sections = (self.num_input_views, self.num_output_views)
            depth_fake_in, depth_fake_out = torch.split(fake['depth'],
                                                        sections,
                                                        dim=1)
            mask_fake_in, mask_fake_out = torch.split(fake['mask'],
                                                      sections,
                                                      dim=1)

            image_reproj, depth_reproj, cam_dists_r, cam_dists_t = ibr.reproject_views_batch(
                image_in=batch['in']['image'],
                depth_in=depth_fake_in,
                depth_out=depth_fake_out,
                camera_in=batch['in']['camera'],
                camera_out=batch['out_gt']['camera'])
            image_reproj = image_reproj * mask_fake_out.unsqueeze(2)
            depth_reproj = (depth_reproj +
                            1.0) * mask_fake_out.unsqueeze(2) - 1.0

        return (
            image_reproj,
            depth_reproj,
            mask_fake_out.contiguous(),
            depth_fake_out.contiguous(),
            cam_dists_r,
            cam_dists_t,
        )
Example #12
    def _params_to_camera(cls, params, camera_init, device='cpu'):
        if len(params.shape) == 1:
            params = params.unsqueeze(0)

        intrinsic = camera_init.intrinsic.expand(params.shape[0], -1, -1).to(device)
        translations = params[:, :3].to(device)
        log_quaternions = params[:, 3:].to(device)
        cameras = Camera(intrinsic=intrinsic,
                         extrinsic=None,
                         translation=translations,
                         log_quaternion=log_quaternions,
                         width=camera_init.width,
                         height=camera_init.height,
                         z_span=camera_init.z_span).to(device)
        return cameras
Example #13
    def load(cls, path, frames=None) -> 'Observation':
        if isinstance(path, str):
            path = Path(path)

        with open(path / 'cameras.json', 'r') as f:
            camera_json = json.load(f)
        if 'meta' in camera_json:
            meta = camera_json.pop('meta')
        else:
            meta = {}

        cameras = Camera(**{
            k: torch.tensor(v, dtype=torch.float32) if isinstance(v, list) else v
            for k, v in camera_json.items()
        })

        color_ims = []
        depth_ims = []
        mask_ims = []
        if frames is None:
            inds = list(range(len(cameras)))
        elif isinstance(frames, int):
            inds = [frames]
        else:
            inds = frames

        cameras = cameras[inds]

        for i in inds:
            color_ims.append(
                imageio.imread(path / f"{i:04d}.color.png").astype(np.float32)
                / 255.0)
            depth_ims.append(
                imageio.imread(path / f"{i:04d}.depth.png").astype(np.float32)
                / 1000.0)
            mask_ims.append(
                imageio.imread(path / f"{i:04d}.mask.png").astype(bool))

        color = torch.stack(
            [torch.tensor(x).permute(2, 0, 1) for x in color_ims], dim=0)
        depth = torch.stack([torch.tensor(x).unsqueeze(0) for x in depth_ims],
                            dim=0)
        mask = torch.stack(
            [torch.tensor(x).float().unsqueeze(0) for x in mask_ims], dim=0)

        return cls(color, depth, mask, cameras, **meta)
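A minimal sketch of loading a saved observation, assuming the on-disk layout implied above (cameras.json plus NNNN.color/depth/mask.png files); the directory path is hypothetical.

obs = Observation.load('data/captures/object-0001', frames=[0, 4, 8])
obs = obs.to('cuda')  # mirrors the .to(device) calls used by the estimators above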
Example #14
    def _estimate(self, z_obj, target_obs, **kwargs):
        camera_init = self.initial_pose(target_obs)
        camera = pu.sample_cameras_with_estimate(self.num_samples, camera_init).to(
            self.device)
        error = torch.full((self.num_samples,), fill_value=100.0, device=self.device)
        ranking = []

        temp_weight = 1.0 / camera_init.translation[:, -1].mean().item()
        temp_sched = ExponentialScheduler(temp_weight * 0.1,
                                          temp_weight * 0.005, num_steps=self.num_iters)
        logger.info("simulated annealing",
                    temp_weight=temp_weight,
                    temp_sched_range=[temp_sched.initial_value, temp_sched.final_value],
                    n_iters=self.num_iters,
                    n_samples=self.num_samples)

        target_obs = target_obs.to(self.device)

        camera_history = []
        pbar = utils.trange(self.num_iters)
        for step in pbar:
            temperature = temp_sched.get(step)
            camera, error, num_accepted = self._refine_pose(z_obj,
                                                            camera.clone(),
                                                            error.clone(),
                                                            target_obs=target_obs,
                                                            temperature=temperature)
            delta = self._track_best_items(ranking, step, camera, error)
            if delta > 0:
                camera_history.append((error, camera.clone().cpu()))

            pbar.set_description(
                f"E={ranking[0][1]:.05f}, "
                f"T={temperature:.04f}, "
                f"N={num_accepted}/{self.num_samples}")

        cameras = Camera.cat([c for c, e, step in ranking])
        if self.return_camera_history:
            return cameras, camera_history
        else:
            return cameras
Example #15
    def _refine_pose(self, z_obj, target_obs, prev_gmm, gmm, num_elites, camera_init):
        # Sample from blended distribution and then set current distribution to
        # new distribution.
        if prev_gmm is not None:
            sample_gmm = self._combined_gmm(prev_gmm, gmm, self.learning_rate)
        else:
            sample_gmm = gmm

        num_samples = self.num_samples // 4 if self.sample_flipped else self.num_samples
        params = self._sample_poses(sample_gmm, num_samples)
        cameras = self._params_to_camera(params, camera_init=camera_init, device=self.device)

        if self.sample_flipped:
            cameras = Camera.cat([
                cameras,
                pu.flip_camera(cameras, axis=(0.0, 0.0, 1.0)),
                pu.flip_camera(cameras, axis=(0.0, 1.0, 0.0)),
                pu.flip_camera(cameras, axis=(1.0, 0.0, 0.0)),
            ])

        if self.loss_weights.get('latent', 0.0) > 0.0:
            with torch.no_grad():
                z_target_latent = self.model.compute_latent_code(target_obs, cameras[0])
        else:
            z_target_latent = None

        z_pred_depth, z_pred_mask_logits, z_pred_latent, z_camera = self._render_observation(z_obj, cameras)
        loss_dict = self.loss_func(target_obs, z_pred_depth, z_pred_mask_logits, z_camera,
                                   z_pred_latent=z_pred_latent,
                                   z_target_latent=z_target_latent)
        loss = sum(weigh_losses(loss_dict, self.loss_weights).values())
        sorted_inds = torch.argsort(loss)
        elite_inds = sorted_inds[:num_elites]
        if self.verbose:
            logger.info('pose error', **{k: v[elite_inds[0]].item() for k, v in loss_dict.items()})

        elite_losses = loss[elite_inds]
        elite_cameras = cameras[elite_inds]

        return elite_cameras, elite_losses
Example #16
def sample_cameras_with_estimate(n,
                                 camera_est,
                                 translation_std=0.0,
                                 hemisphere=False,
                                 upright=False) -> Camera:
    device = camera_est.device
    intrinsic = camera_est.intrinsic.expand(n, -1, -1)
    translation = camera_est.translation.expand(n, -1)
    translation = translation + torch.randn_like(translation) * translation_std
    # quaternion = three.orientation.disk_sample_quats(n, min_angle=min_angle)
    # quaternion = three.orientation.evenly_distributed_quats(n)
    quaternion = three.orientation.evenly_distributed_quats(
        n, hemisphere=hemisphere, upright=upright)
    extrinsic = three.to_extrinsic_matrix(translation.cpu(),
                                          quaternion).to(device)
    viewport = camera_est.viewport.expand(n, -1)

    return Camera(intrinsic,
                  extrinsic,
                  camera_est.z_span,
                  width=camera_est.width,
                  height=camera_est.height,
                  viewport=viewport)
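A hedged sketch of seeding a pose search with this sampler, combined with estimate_initial_pose from Example #5; the sample count and translation noise below are assumptions, and obs is a hypothetical Observation as in the earlier sketches.

camera_est = estimate_initial_pose(obs.depth, obs.mask, obs.camera.intrinsic,
                                   width=obs.camera.width, height=obs.camera.height)
cameras = sample_cameras_with_estimate(64, camera_est,
                                       translation_std=0.01,
                                       upright=True)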
Example #17
    def collate(cls, observations):
        color = torch.cat([o.color for o in observations], dim=0)
        depth = torch.cat([o.depth for o in observations], dim=0)
        mask = torch.cat([o.mask for o in observations], dim=0)
        camera = Camera.cat([o.camera for o in observations])
        return cls(color, depth, mask, camera, **observations[0].meta)
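As an illustration of the collate helper, per-frame observations can be merged into one batched Observation; obs_a and obs_b are hypothetical single-frame observations.

obs_batch = Observation.collate([obs_a, obs_b])
print(len(obs_batch.camera))  # one camera per collated frame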
Example #18
    def run_iteration(self, batch, train, is_step):
        self.mark_time()
        # Update depth criterion k if applicable.
        if 'hard_' in self.g_depth_recon_loss_type:
            self._g_depth_recon_criterion.k = int(
                self._g_depth_recon_k_scheduler.get(self.epoch))

        batch = process_batch(batch, self.cube_size, self.camera_dist,
                              self._sculptor.in_size, self.device,
                              self.random_orientation)

        if self.reconstruct_input:
            recon_camera = Camera.vcat(
                (batch['in_gt']['camera'], batch['out_gt']['camera']),
                batch_size=self.batch_size)
            recon_mask = torch.cat(
                (batch['in_gt']['mask'], batch['out_gt']['mask']), dim=1)
            recon_image = torch.cat(
                (batch['in_gt']['image'], batch['out_gt']['image']), dim=1)
            recon_depth = torch.cat(
                (batch['in_gt']['depth'], batch['out_gt']['depth']), dim=1)
        else:
            recon_camera = batch['out_gt']['camera']
            recon_mask = batch['out_gt']['mask']
            recon_image = batch['out_gt']['image']
            recon_depth = batch['out_gt']['depth']

        if not self.color_random_background or self.crop_random_background:
            batch['in']['image'] = batch['in']['image'] * batch['in']['mask']

        if not self.depth_random_background or self.crop_random_background:
            batch['in']['depth'] = mask_normalized_depth(
                batch['in']['depth'], batch['in']['mask'])

        depth_in = None
        if self.generator_input_depth:
            depth_noise = self._depth_noise_dist.sample(
                batch['in']['depth'].size()).to(self.device)
            depth_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)

        data_process_time = self.mark_time()

        with autocast():
            # Evaluate generator.
            z_obj, z_extra = self._sculptor.encode(
                self._fuser,
                camera=batch['in']['camera'],
                color=batch['in']['image'],
                depth=depth_in,
                mask=batch['in']['mask'],
                data_parallel=self.data_parallel)
            fake_image, fake_depth, fake_mask, fake_mask_logits, fake_vox_depth = \
                self._run_photographer(z_obj, recon_camera, recon_mask)

            if 'blend_weights' in z_extra:
                z_weights = z_extra['blend_weights']
            else:
                z_weights = None

            # Train discriminator.
            if self._discriminator:
                d_real, d_fake_d, d_fake_g = self._run_discriminator(
                    fake_image, fake_depth, fake_mask, recon_image,
                    recon_depth, recon_mask)
                loss_d_real = multiscale_lsgan_loss(d_real, 1)
                loss_d_fake = multiscale_lsgan_loss(d_fake_d, 0)
                loss_d = loss_d_real + loss_d_fake
                loss_g_gan = multiscale_lsgan_loss(d_fake_g, 1)

                if train:
                    loss_d.backward()
                    if is_step:
                        self._optimizers['discriminator'].step()

                self.plotter.put_scalar('loss/discriminator/real', loss_d_real)
                self.plotter.put_scalar('loss/discriminator/fake', loss_d_fake)
                self.plotter.put_scalar('loss/discriminator/total', loss_d)
            else:
                loss_g_gan = torch.tensor(0.0, device=self.device)

            # Train generator.
            if self.predict_color:
                loss_g_color_recon = reduce_loss(
                    self._g_color_recon_criterion(fake_image, recon_image))
            else:
                loss_g_color_recon = torch.tensor(0.0, device=self.device)

            if self.predict_depth or self.use_occlusion_depth:
                loss_g_depth_recon = reduce_loss(
                    self._g_depth_recon_criterion(fake_depth, recon_depth))
            else:
                loss_g_depth_recon = torch.tensor(0.0, device=self.device)

            if self.predict_mask:
                if self.g_mask_recon_loss_type == 'binary_cross_entropy':
                    y_mask = fake_mask_logits
                else:
                    y_mask = fake_mask
                loss_g_mask_recon = reduce_loss(
                    self._g_mask_recon_criterion(y_mask, recon_mask))
                loss_g_mask_beta = beta_prior_loss(
                    fake_mask,
                    alpha=self.g_mask_beta_loss_param,
                    beta=self.g_mask_beta_loss_param)
            else:
                loss_g_mask_recon = torch.tensor(0.0, device=self.device)
                loss_g_mask_beta = torch.tensor(0.0, device=self.device)

            loss_g = (self.g_gan_loss_weight * loss_g_gan +
                      self.g_color_recon_loss_weight * loss_g_color_recon +
                      self.g_depth_recon_loss_weight * loss_g_depth_recon +
                      self.g_mask_recon_loss_weight * loss_g_mask_recon +
                      self.g_mask_beta_loss_weight *
                      loss_g_mask_beta) / self.batch_groups

        if train:
            if self.kwargs.get('use_amp', False):
                self._scaler.scale(loss_g).backward()
            else:
                loss_g.backward()

            if is_step:
                if self.kwargs.get('use_amp', False):
                    self._scaler.step(self._optimizers['generator'])
                    self._scaler.update()
                else:
                    self._optimizers['generator'].step()

        with torch.no_grad():
            if self.predict_depth:
                self.plotter.put_scalar('error/depth/l1',
                                        F.l1_loss(fake_depth, recon_depth))
            if self.reconstruct_input:
                self.plotter.put_scalar(
                    'error/depth/input_l1',
                    F.l1_loss(fake_depth[:, :self.num_input_views],
                              batch['in_gt']['depth']))
                self.plotter.put_scalar(
                    'error/depth/output_l1',
                    F.l1_loss(fake_depth[:, self.num_input_views:],
                              batch['out_gt']['depth']))
            if self.predict_mask:
                self.plotter.put_scalar(
                    'error/mask/cross_entropy',
                    F.binary_cross_entropy_with_logits(fake_mask_logits,
                                                       recon_mask))
                self.plotter.put_scalar('error/mask/l1',
                                        F.l1_loss(fake_mask, recon_mask))

        compute_time = self.mark_time()

        self.plotter.put_scalar('loss/generator/gan', loss_g_gan)
        self.plotter.put_scalar('loss/generator/recon/color',
                                loss_g_color_recon)
        self.plotter.put_scalar('loss/generator/recon/depth',
                                loss_g_depth_recon)
        self.plotter.put_scalar('loss/generator/recon/mask', loss_g_mask_recon)
        self.plotter.put_scalar('loss/generator/recon/mask_beta',
                                loss_g_mask_beta)
        self.plotter.put_scalar('loss/generator/total', loss_g)

        self.plotter.put_scalar('params/input_noise_weight',
                                self.input_noise_weight)
        if hasattr(self._g_depth_recon_criterion, 'k'):
            self.plotter.put_scalar('params/depth_loss_k',
                                    self._g_depth_recon_criterion.k)

        self.plotter.put_scalar('time/data_process', data_process_time)
        self.plotter.put_scalar('time/compute', compute_time)
        plot_scalar_time = self.mark_time()
        self.plotter.put_scalar('time/plot/scalars', plot_scalar_time)

        if self.plotter.is_it_time_yet('histogram'):
            if self.predict_color:
                self.plotter.put_histogram('image_fake', fake_image)
                self.plotter.put_histogram('image_real', recon_image)
            if self.predict_mask:
                self.plotter.put_histogram('mask_fake', fake_mask)
            self.plotter.put_histogram('z_obj', z_obj)
            if z_weights is not None:
                self.plotter.put_histogram('z_weights', z_weights)
        plot_histogram_time = self.mark_time()
        self.plotter.put_scalar('time/plot/histogram', plot_histogram_time)

        if self.plotter.is_it_time_yet('show'):
            self.plotter.put_image(
                'inputs',
                viz.make_grid([
                    gan_denormalize(batch['in']['image']),
                    viz.colorize_depth(batch['in']['depth'])
                    if self.generator_input_depth else None,
                    viz.colorize_tensor(batch['in']['mask'])
                    if self.generator_input_mask else None,
                ],
                              row_size=4,
                              stride=2,
                              output_size=64))
            with torch.no_grad():
                self.plotter.put_image(
                    'reconstruction',
                    viz.make_grid([
                        gan_denormalize(recon_image),
                        gan_denormalize(fake_image) if
                        (fake_image is not None) else None,
                        viz.colorize_depth(recon_depth),
                        viz.colorize_depth(fake_depth) if
                        (fake_depth is not None) else None,
                        viz.colorize_tensor(
                            (recon_depth.cpu() - fake_depth.cpu()).abs()) if
                        (fake_depth is not None) else None,
                        viz.colorize_tensor(recon_mask),
                        viz.colorize_tensor(fake_mask) if
                        (fake_mask is not None) else None,
                        viz.colorize_tensor(
                            (recon_mask.cpu() - fake_mask.cpu()).abs()) if
                        (fake_mask is not None) else None,
                    ],
                                  stride=8))
        plot_images_time = self.mark_time()
        self.plotter.put_scalar('time/plot/images', plot_images_time)
Example #19
    def _optimize_camera(self, z_obj, target_obs, cameras, iters, ranking):
        optimizers = []
        schedulers = []
        param_cameras = [pu.parameterize_camera(camera, optimize_viewport=True) for camera in
                         cameras]
        for camera in param_cameras:
            parameters = [camera.log_quaternion, camera.translation, camera.viewport]
            optimizer = self.get_optimizer(self.optimizer, parameters, lr=self.learning_rate)
            optimizers.append(optimizer)
            schedulers.append(
                optim.lr_scheduler.ReduceLROnPlateau(
                    optimizer,
                    patience=self.lr_reduce_patience,
                    threshold=self.lr_reduce_threshold,
                    factor=self.lr_reduce_factor,
                    verbose=self.verbose))

        pbar = utils.trange(iters)
        stat_history = {}
        converge_count = 0
        camera_history = []

        for step in pbar:
            for optimizer in optimizers:
                optimizer.zero_grad()
            cameras = Camera.cat(param_cameras)
            if self.loss_weights.get('latent', 0.0) > 0.0:
                with torch.no_grad():
                    z_target_latent = self.model.compute_latent_code(target_obs, cameras)
            else:
                z_target_latent = None
            z_depth, z_mask, z_mask_logits, z_pred_latent = self._render_observation(z_obj, cameras)
            optim_weights = copy.copy(self.loss_weights)
            optim_weights.update({k: v.get(step) for k, v in self.loss_schedules.items()})

            loss_dict = self.loss_func(target_obs, z_depth, z_mask_logits, cameras,
                                       z_pred_latent=z_pred_latent, z_target_latent=z_target_latent)
            optim_loss = sum(weigh_losses(loss_dict, optim_weights).values())
            optim_loss.mean().backward()
            rank_loss = sum(weigh_losses(loss_dict, self.loss_weights).values())

            best_idx = torch.argmin(rank_loss)
            detached_cameras = pu.deparameterize_camera(cameras.uncrop()).clone()
            angle_dists = three.quaternion.angular_distance(
                detached_cameras.quaternion, target_obs.camera.quaternion).squeeze()
            translation_dists = torch.norm(detached_cameras.translation
                                           - target_obs.camera.translation, dim=1).squeeze()
            if self.return_camera_history:
                camera_history.append((rank_loss.detach().cpu(), detached_cameras.cpu()))

            # Save best cameras in ranking list.
            delta = self._track_best_items(ranking, step,
                                           items=detached_cameras.cpu(),
                                           loss=rank_loss)

            pbar.set_description(f"idx={best_idx}, loss={rank_loss[best_idx].item():.04f}"
                                 f", depth={loss_dict['depth'][best_idx].item():.04f}"
                                 f", ov_depth={loss_dict['ov_depth'][best_idx].item():.04f}"
                                 f", mask={loss_dict['mask'][best_idx].item():.04f}"
                                 f", iou={loss_dict['iou'][best_idx].item():.04f}"
                                 f", latent={loss_dict.get('latent', [0.0]*len(cameras))[best_idx]:.04f}"
                                 f", converge={converge_count}"
                                 f", angle={angle_dists[best_idx].item() / math.pi * 180:.02f}°"
                                 f", trans={translation_dists[best_idx].item():.04f}"
                                 f"")

            if self.track_stats:
                self._record_stat_dict(stat_history, {
                    **{f'{k}_loss': v.detach().cpu() for k, v in loss_dict.items()},
                    **{f'{k}_weight': v for k, v in optim_weights.items()},
                    'delta': delta,
                    'converge_count': converge_count,
                    'angle_dist': angle_dists.cpu(),
                    'trans_dist': translation_dists.cpu(),
                    'optim_loss': optim_loss.detach().cpu(),
                    'rank_loss': rank_loss.detach().cpu(),
                    # 'translation_grad': (cameras.translation.grad
                    #                      if cameras.translation.grad is not None
                    #                      else torch.zeros_like(cameras.translation)),
                    # 'rotation_grad': (cameras.log_quaternion.grad
                    #                   if cameras.log_quaternion.grad is not None
                    #                   else torch.zeros_like(cameras.log_quaternion)),
                    # 'viewport_grad': cameras.viewport.grad,
                })

            for i, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
                optimizer.step()
                scheduler.step(rank_loss[i])

            if delta < self.converge_threshold:
                converge_count += 1
            elif delta > self.converge_threshold:
                converge_count = 0

            if converge_count >= self.converge_patience:
                logger.info("convergence threshold reached", step=step, delta=delta,
                            count=converge_count)
                pbar.close()
                break

        return stat_history, camera_history