def _estimate(self, z_obj, target_obs, **kwargs):
    if 'camera' in kwargs:
        camera = kwargs['camera']
    else:
        camera = self.initial_pose(target_obs)
        camera = pu.sample_cameras_with_estimate(
            n=self.num_samples, camera_est=camera)

    target_obs = target_obs.to(self.device)

    # Mask parameters.
    # Optimize the 'zoomed' camera.
    camera = camera.zoom(None, self.model.input_size,
                         self.model.camera_dist).to(self.device)

    ranking = []
    stat_history, camera_history = self._optimize_camera(
        z_obj, target_obs, camera, iters=self.num_iters, ranking=ranking)
    logger.info('best camera', step=ranking[0][2], loss=ranking[0][1])

    best_cameras = Camera.cat([c for c, loss, step in ranking])
    if self.track_stats and self.return_camera_history:
        return best_cameras, stat_history, camera_history
    elif self.track_stats:
        return best_cameras, stat_history
    elif self.return_camera_history:
        return best_cameras, camera_history
    else:
        return best_cameras
def parse_poserbpf_cameras(seq_path):
    mat_paths = sorted(seq_path.glob('*.mat'))
    inds = [int(x.name.split('.')[0]) for x in mat_paths]
    cameras = []
    for mat in mat_paths:
        cameras.append(load_poserbpf_camera(mat))
    return Camera.cat(cameras)
def load_poserbpf_camera(mat_path, key='poses'):
    mat = loadmat(mat_path)
    intrinsic = torch.tensor(mat['intrinsic_matrix']).float()
    pose = torch.tensor(mat[key]).squeeze().float()
    quat = pose[:4]
    translation = pose[4:]
    extrinsic = three.to_extrinsic_matrix(translation, quat)
    camera = Camera(intrinsic=intrinsic, extrinsic=extrinsic)
    return camera
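# Illustrative usage (not part of the original code): load a PoseRBPF result
# sequence into one batched Camera. The .mat layout assumed here is exactly
# what load_poserbpf_camera reads: an 'intrinsic_matrix' key plus a 'poses'
# key holding a 7-vector (unit quaternion followed by translation). The path
# below is a placeholder.
def example_load_poserbpf_sequence():
    from pathlib import Path
    cameras = parse_poserbpf_cameras(Path('results/poserbpf/seq_0001'))
    return cameras[0]  # Camera.cat yields a batch; index to get one frame.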
def item_to_obs(item):
    height, width = item['color'].shape[-2:]
    return latentfusion.observation.Observation(
        item['color'].unsqueeze(0),
        item['depth'].unsqueeze(0).unsqueeze(0),
        item['mask'].unsqueeze(0).unsqueeze(0).float(),
        Camera(intrinsic=item['intrinsic'], extrinsic=item['extrinsic'],
               width=width, height=height))
def estimate_initial_pose(depth, mask, intrinsic, width, height) -> Camera:
    """Estimate the initial pose based on depth."""
    translation = torch.stack(estimate_translation(depth, mask, intrinsic), dim=-1)
    rotation = three.quaternion.identity(intrinsic.shape[0], intrinsic.device)
    extrinsic = three.to_extrinsic_matrix(translation, rotation)
    camera = Camera(intrinsic, extrinsic, height=height, width=width)
    return camera
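# A minimal sketch of the estimate_translation helper used above (the real
# implementation lives elsewhere in the codebase). Assuming depth (B, H, W)
# in meters, a boolean mask (B, H, W) with at least one foreground pixel per
# element, and intrinsics (B, 3, 3), it back-projects the masked centroid
# through the pinhole model to get a rough per-element translation, returned
# as an (x, y, z) tuple that the caller stacks along dim=-1.
def estimate_translation_sketch(depth, mask, intrinsic):
    import torch
    xs, ys, zs = [], [], []
    for d, m, K in zip(depth, mask, intrinsic):
        v, u = torch.nonzero(m, as_tuple=True)  # foreground pixel coords
        z = d[v, u].median()                    # robust depth estimate
        u_c, v_c = u.float().mean(), v.float().mean()
        fx, fy = K[0, 0], K[1, 1]
        cx, cy = K[0, 2], K[1, 2]
        xs.append((u_c - cx) * z / fx)          # pinhole back-projection
        ys.append((v_c - cy) * z / fy)
        zs.append(z)
    return torch.stack(xs), torch.stack(ys), torch.stack(zs)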
def _process_batch(batch, rotation, cube_size, camera_dist, input_size, device, is_gt):
    # Collapse viewpoint dimension to batch dimension:
    #   (B, V, C, H, W) => (B*V, C, H, W)
    batch_size = batch['mask'].shape[0]
    extrinsic = bv2b(batch['extrinsic'].to(device))
    intrinsic = bv2b(batch['intrinsic'].to(device))
    mask = bv2b(batch['mask'].unsqueeze(2).float().to(device))
    image = bv2b(gan_normalize(batch['render'].to(device)))
    if 'depth' in batch:
        depth = bv2b(batch['depth'].unsqueeze(2).to(device))
    else:
        depth = None

    # Project image features onto canonical volume.
    camera = Camera(intrinsic, extrinsic, z_span=cube_size / 2.0,
                    height=image.size(2), width=image.size(3)).to(device)
    if rotation is not None:
        camera.rotate(rotation.expand(camera.length, -1))
    # translation = three.uniform(3, -cube_size/16, cube_size/16).view(1, 3).expand(camera.length, -1).to(device)
    # camera.translate(translation)

    _zoom = functools.partial(camera.zoom, target_size=input_size,
                              target_dist=camera_dist)

    out = dict()
    # Zoom camera to canonical distance and size.
    out['image'], out['camera'] = _zoom(image, scale_mode='bilinear')
    out['mask'] = _zoom(mask, scale_mode='nearest')[0]
    if depth is not None:
        out['depth'] = camera.normalize_depth(
            _zoom(depth, scale_mode='nearest')[0])

    if is_gt:
        out['image'] = out['image'] * out['mask']
        if 'depth' in out:
            out['depth'] = mask_normalized_depth(out['depth'], out['mask'])

    # Restore the viewpoint dimension. Only touch keys that exist: 'depth'
    # is absent when the batch carries no depth maps.
    for k in ('image', 'depth', 'mask'):
        if k in out:
            out[k] = b2bv(out[k], batch_size=batch_size)

    return out
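# Minimal sketches of the bv2b / b2bv helpers used above (the real ones live
# elsewhere in the codebase): they fold the view dimension into the batch
# dimension and back, so per-view tensors can be processed as one big batch.
def bv2b_sketch(x):
    # (B, V, ...) -> (B*V, ...)
    return x.reshape(-1, *x.shape[2:])


def b2bv_sketch(x, batch_size):
    # (B*V, ...) -> (B, V, ...)
    return x.reshape(batch_size, -1, *x.shape[1:])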
def _estimate(self, z_obj, target_obs, **kwargs):
    # Compare against None explicitly: Camera truthiness may not reflect
    # "was a camera passed in".
    if kwargs.get('cameras') is not None:
        cameras = kwargs['cameras']
        camera_init = kwargs['cameras'][0]
    else:
        camera_init = self.initial_pose(target_obs)
        cameras = pu.sample_cameras_with_estimate(
            n=self.num_gmm_components * self.num_samples,
            camera_est=camera_init,
            upright=self.init_upright,
            hemisphere=self.init_hemisphere)
    gmm = self._create_gmm(self._camera_to_params(cameras))

    target_obs = target_obs.to(self.device)

    camera_history = []
    prev_gmm = None
    ranking = []
    pbar = utils.trange(self.num_iters)
    for step in pbar:
        # Refine pose.
        _num_elites = int(self.elite_sched.get(step))
        cameras, losses = self._refine_pose(z_obj, target_obs, prev_gmm, gmm,
                                            num_elites=_num_elites,
                                            camera_init=camera_init)
        prev_gmm = gmm
        gmm = self._create_gmm(self._camera_to_params(cameras).cpu())
        delta = self._track_best_items(ranking, step, cameras, losses)
        if delta > 0:
            camera_history.append(
                (losses, Camera.cat([c for c, e, step in ranking])))
        pbar.set_description(
            f"best_error={ranking[0][1]:.05f}, num_elite={_num_elites}")
        # gmm_camera = self._params_to_camera(torch.tensor(gmm.means_, dtype=torch.float32),
        #                                     camera_init=camera_init)

    logger.info('best camera', step=ranking[0][2], loss=ranking[0][1])
    cameras = Camera.cat([c for c, e, step in ranking])
    if self.return_camera_history:
        return cameras, camera_history
    else:
        return cameras
@classmethod
def from_dict(cls, d):
    height, width = d['color'].shape[-2:]
    camera = Camera(d['intrinsic'], d['extrinsic'], width=width, height=height)
    return cls(
        d['color'],
        d['depth'].unsqueeze(-3),  # Create channel dimension.
        d['mask'].unsqueeze(-3).float(),
        camera)
def render_observation(renderer, scene):
    color, depth, mask = renderer.render(scene)
    camera = Camera(scene.intrinsic, scene.extrinsic,
                    width=renderer.width, height=renderer.height)
    return Observation(color.permute(2, 0, 1).unsqueeze(0),
                       depth.unsqueeze(0).unsqueeze(0),
                       mask.unsqueeze(0).unsqueeze(0),
                       camera,
                       object_scale=scene.obj.scale)
def zoom(self, target_dist, target_size, camera: Camera = None):
    if camera is None:
        camera = self.camera
    color, new_camera = camera.zoom(self.color, target_size, target_dist,
                                    scale_mode='bilinear')
    depth, _ = camera.zoom(self.depth, target_size, target_dist,
                           scale_mode='nearest')
    mask, _ = camera.zoom(self.mask, target_size, target_dist,
                          scale_mode='nearest')
    kwargs = copy.deepcopy(self.meta)
    kwargs['is_zoomed'] = True
    return Observation(color, depth, mask, new_camera, **kwargs)
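# Illustrative usage (not part of the original code): crop and rescale an
# observation so the object appears at the canonical distance and input size
# the model was trained with. model.input_size and model.camera_dist mirror
# the attributes referenced by the estimators above.
def example_zoom_observation(obs, model):
    return obs.zoom(target_dist=model.camera_dist,
                    target_size=model.input_size)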
def _render_reprojections(self, batch):
    recon_camera = Camera.vcat(
        (batch['in_gt']['camera'], batch['out_gt']['camera']),
        batch_size=self.batch_size)

    if self._recon_params['generator_input_depth']:
        depth_noise = self._depth_noise_dist.sample(
            batch['in']['depth'].size()).to(self.device)
        depth_real_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)
    else:
        depth_real_in = None

    # Create IBR images.
    with torch.set_grad_enabled(self.train_recon):
        z_obj, z_extra = self._sculptor.encode(
            self._fuser,
            camera=batch['in']['camera'],
            color=batch['in']['image'],
            depth=depth_real_in,
            mask=batch['in']['mask'],
            data_parallel=self.data_parallel)
        fake, _, _ = self._photographer.decode(
            z_obj, recon_camera, data_parallel=self.data_parallel)

    # Reshape things to BVCHW.
    sections = (self.num_input_views, self.num_output_views)
    depth_fake_in, depth_fake_out = torch.split(fake['depth'], sections, dim=1)
    mask_fake_in, mask_fake_out = torch.split(fake['mask'], sections, dim=1)

    image_reproj, depth_reproj, cam_dists_r, cam_dists_t = \
        ibr.reproject_views_batch(
            image_in=batch['in']['image'],
            depth_in=depth_fake_in,
            depth_out=depth_fake_out,
            camera_in=batch['in']['camera'],
            camera_out=batch['out_gt']['camera'])
    image_reproj = image_reproj * mask_fake_out.unsqueeze(2)
    depth_reproj = (depth_reproj + 1.0) * mask_fake_out.unsqueeze(2) - 1.0

    return (
        image_reproj,
        depth_reproj,
        mask_fake_out.contiguous(),
        depth_fake_out.contiguous(),
        cam_dists_r,
        cam_dists_t,
    )
@classmethod
def _params_to_camera(cls, params, camera_init, device='cpu'):
    if len(params.shape) == 1:
        params = params.unsqueeze(0)
    intrinsic = camera_init.intrinsic.expand(params.shape[0], -1, -1).to(device)
    translations = params[:, :3].to(device)
    log_quaternions = params[:, 3:].to(device)
    cameras = Camera(intrinsic=intrinsic,
                     extrinsic=None,
                     translation=translations,
                     log_quaternion=log_quaternions,
                     width=camera_init.width,
                     height=camera_init.height,
                     z_span=camera_init.z_span).to(device)
    return cameras
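# A minimal sketch of the inverse mapping _camera_to_params used by the
# GMM-based estimator above (the real implementation is not shown). Each pose
# is flattened into one vector, translation first and log-quaternion second,
# matching how _params_to_camera slices params[:, :3] and params[:, 3:].
def camera_to_params_sketch(cameras):
    import torch
    return torch.cat([cameras.translation, cameras.log_quaternion], dim=-1)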
@classmethod
def load(cls, path, frames=None) -> 'Observation':
    if isinstance(path, str):
        path = Path(path)

    with open(path / 'cameras.json', 'r') as f:
        camera_json = json.load(f)
    if 'meta' in camera_json:
        meta = camera_json.pop('meta')
    else:
        meta = {}
    cameras = Camera(**{
        k: torch.tensor(v, dtype=torch.float32) if isinstance(v, list) else v
        for k, v in camera_json.items()
    })

    color_ims = []
    depth_ims = []
    mask_ims = []
    if frames is None:
        inds = list(range(len(cameras)))
    elif isinstance(frames, int):
        inds = [frames]
    else:
        inds = frames

    cameras = cameras[inds]
    for i in inds:
        color_ims.append(
            imageio.imread(path / f"{i:04d}.color.png").astype(np.float32) / 255.0)
        depth_ims.append(
            imageio.imread(path / f"{i:04d}.depth.png").astype(np.float32) / 1000.0)
        # np.bool was removed in NumPy 1.24; use the builtin bool instead.
        mask_ims.append(
            imageio.imread(path / f"{i:04d}.mask.png").astype(bool))

    color = torch.stack(
        [torch.tensor(x).permute(2, 0, 1) for x in color_ims], dim=0)
    depth = torch.stack(
        [torch.tensor(x).unsqueeze(0) for x in depth_ims], dim=0)
    mask = torch.stack(
        [torch.tensor(x).float().unsqueeze(0) for x in mask_ims], dim=0)

    return cls(color, depth, mask, cameras, **meta)
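# For reference, the on-disk layout that Observation.load expects (inferred
# from the reads above; written by the matching save routine, not shown):
#
#   <path>/cameras.json        Camera fields as lists (+ optional 'meta' dict)
#   <path>/0000.color.png      RGB, uint8; divided by 255 -> [0, 1]
#   <path>/0000.depth.png      uint16 depth in millimeters; /1000 -> meters
#   <path>/0000.mask.png       binary object mask
#   <path>/0001.color.png      ... one triple per frame, indexed %04d
#
# Example: Observation.load('captures/obj_01', frames=[0, 2, 4]) loads only
# those frames together with the matching cameras (path is a placeholder).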
def _estimate(self, z_obj, target_obs, **kwargs):
    camera_init = self.initial_pose(target_obs)
    camera = pu.sample_cameras_with_estimate(
        self.num_samples, camera_init).to(self.device)
    error = torch.full((self.num_samples,), fill_value=100.0,
                       device=self.device)
    ranking = []

    temp_weight = 1.0 / camera_init.translation[:, -1].mean().item()
    temp_sched = ExponentialScheduler(temp_weight * 0.1, temp_weight * 0.005,
                                      num_steps=self.num_iters)
    logger.info("simulated annealing",
                temp_weight=temp_weight,
                temp_sched_range=[temp_sched.initial_value,
                                  temp_sched.final_value],
                n_iters=self.num_iters,
                n_samples=self.num_samples)

    target_obs = target_obs.to(self.device)

    camera_history = []
    pbar = utils.trange(self.num_iters)
    for step in pbar:
        temperature = temp_sched.get(step)
        camera, error, num_accepted = self._refine_pose(
            z_obj, camera.clone(), error.clone(),
            target_obs=target_obs, temperature=temperature)
        delta = self._track_best_items(ranking, step, camera, error)
        if delta > 0:
            camera_history.append((error, camera.clone().cpu()))
        pbar.set_description(
            f"E={ranking[0][1]:.05f}, "
            f"T={temperature:.04f}, "
            f"N={num_accepted}/{self.num_samples}")

    cameras = Camera.cat([c for c, e, step in ranking])
    if self.return_camera_history:
        return cameras, camera_history
    else:
        return cameras
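# A minimal sketch of the Metropolis acceptance rule implied by the annealing
# loop above; the real _refine_pose also perturbs the poses and renders the
# latent object to score them, which is omitted here. Candidates with lower
# error are always accepted; worse ones are accepted with probability
# exp(-(new - old) / T), so the early high temperatures explore while the
# decayed temperature exploits. Function and argument names are illustrative.
def metropolis_accept_sketch(old_error, new_error, temperature):
    import torch
    accept_prob = torch.exp(-(new_error - old_error) / temperature).clamp(max=1.0)
    accepted = torch.rand_like(accept_prob) < accept_prob
    return accepted, int(accepted.sum())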
def _refine_pose(self, z_obj, target_obs, prev_gmm, gmm, num_elites, camera_init):
    # Sample from the blended distribution and then set the current
    # distribution to the new distribution.
    if prev_gmm is not None:
        sample_gmm = self._combined_gmm(prev_gmm, gmm, self.learning_rate)
    else:
        sample_gmm = gmm

    # Each sample yields four cameras when flipping is enabled (the original
    # plus three axis flips), so sample a quarter as many poses.
    num_samples = (self.num_samples // 4
                   if self.sample_flipped else self.num_samples)
    params = self._sample_poses(sample_gmm, num_samples)
    cameras = self._params_to_camera(params, camera_init=camera_init,
                                     device=self.device)
    if self.sample_flipped:
        cameras = Camera.cat([
            cameras,
            pu.flip_camera(cameras, axis=(0.0, 0.0, 1.0)),
            pu.flip_camera(cameras, axis=(0.0, 1.0, 0.0)),
            pu.flip_camera(cameras, axis=(1.0, 0.0, 0.0)),
        ])

    if self.loss_weights.get('latent', 0.0) > 0.0:
        with torch.no_grad():
            z_target_latent = self.model.compute_latent_code(target_obs,
                                                             cameras[0])
    else:
        z_target_latent = None

    z_pred_depth, z_pred_mask_logits, z_pred_latent, z_camera = \
        self._render_observation(z_obj, cameras)
    loss_dict = self.loss_func(target_obs, z_pred_depth, z_pred_mask_logits,
                               z_camera,
                               z_pred_latent=z_pred_latent,
                               z_target_latent=z_target_latent)
    loss = sum(weigh_losses(loss_dict, self.loss_weights).values())

    sorted_inds = torch.argsort(loss)
    elite_inds = sorted_inds[:num_elites]
    if self.verbose:
        logger.info('pose error',
                    **{k: v[elite_inds[0]].item()
                       for k, v in loss_dict.items()})
    elite_losses = loss[elite_inds]
    elite_cameras = cameras[elite_inds]
    return elite_cameras, elite_losses
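# A minimal sketch of the _combined_gmm blend used above, assuming sklearn
# GaussianMixture objects (gmm.means_ is referenced elsewhere in this file)
# whose components are aligned index-wise. This is the standard smoothed
# cross-entropy-method update: sampling from an interpolation of the previous
# and newly fitted mixtures keeps the search from collapsing too quickly.
def combined_gmm_sketch(prev_gmm, new_gmm, learning_rate):
    import copy
    blended = copy.deepcopy(new_gmm)
    blended.weights_ = ((1 - learning_rate) * prev_gmm.weights_
                        + learning_rate * new_gmm.weights_)
    blended.means_ = ((1 - learning_rate) * prev_gmm.means_
                      + learning_rate * new_gmm.means_)
    blended.covariances_ = ((1 - learning_rate) * prev_gmm.covariances_
                            + learning_rate * new_gmm.covariances_)
    return blended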
def sample_cameras_with_estimate(n, camera_est, translation_std=0.0,
                                 hemisphere=False, upright=False) -> Camera:
    device = camera_est.device
    intrinsic = camera_est.intrinsic.expand(n, -1, -1)
    translation = camera_est.translation.expand(n, -1)
    translation = translation + torch.randn_like(translation) * translation_std
    # quaternion = three.orientation.disk_sample_quats(n, min_angle=min_angle)
    # quaternion = three.orientation.evenly_distributed_quats(n)
    quaternion = three.orientation.evenly_distributed_quats(
        n, hemisphere=hemisphere, upright=upright)
    extrinsic = three.to_extrinsic_matrix(translation.cpu(), quaternion).to(device)
    viewport = camera_est.viewport.expand(n, -1)
    return Camera(intrinsic, extrinsic, camera_est.z_span,
                  width=camera_est.width, height=camera_est.height,
                  viewport=viewport)
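# Illustrative usage (not part of the original code): scatter n candidate
# cameras around a depth-based initial estimate. Rotations are spread evenly
# over SO(3) (optionally hemisphere- or upright-constrained) while the
# translation is copied from the estimate, with optional Gaussian jitter.
# Tensor shapes are assumed to match what estimate_translation expects.
def example_sample_candidates(target_obs, n=256):
    camera_est = estimate_initial_pose(target_obs.depth, target_obs.mask,
                                       target_obs.camera.intrinsic,
                                       width=target_obs.camera.width,
                                       height=target_obs.camera.height)
    return sample_cameras_with_estimate(n, camera_est, translation_std=0.01)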
@classmethod
def collate(cls, observations):
    color = torch.cat([o.color for o in observations], dim=0)
    depth = torch.cat([o.depth for o in observations], dim=0)
    mask = torch.cat([o.mask for o in observations], dim=0)
    camera = Camera.cat([o.camera for o in observations])
    return cls(color, depth, mask, camera, **observations[0].meta)
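# Illustrative usage (not part of the original code): merge single-frame
# observations into one batched Observation. Note that collate keeps only
# the metadata of the first observation, so it assumes homogeneous meta.
def example_collate(observations):
    batched = Observation.collate(observations)
    assert len(batched.camera) == sum(len(o.camera) for o in observations)
    return batched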
def run_iteration(self, batch, train, is_step):
    self.mark_time()

    # Update depth criterion k if applicable.
    if 'hard_' in self.g_depth_recon_loss_type:
        self._g_depth_recon_criterion.k = int(
            self._g_depth_recon_k_scheduler.get(self.epoch))

    batch = process_batch(batch, self.cube_size, self.camera_dist,
                          self._sculptor.in_size, self.device,
                          self.random_orientation)

    if self.reconstruct_input:
        recon_camera = Camera.vcat(
            (batch['in_gt']['camera'], batch['out_gt']['camera']),
            batch_size=self.batch_size)
        recon_mask = torch.cat(
            (batch['in_gt']['mask'], batch['out_gt']['mask']), dim=1)
        recon_image = torch.cat(
            (batch['in_gt']['image'], batch['out_gt']['image']), dim=1)
        recon_depth = torch.cat(
            (batch['in_gt']['depth'], batch['out_gt']['depth']), dim=1)
    else:
        recon_camera = batch['out_gt']['camera']
        recon_mask = batch['out_gt']['mask']
        recon_image = batch['out_gt']['image']
        recon_depth = batch['out_gt']['depth']

    if not self.color_random_background or self.crop_random_background:
        batch['in']['image'] = batch['in']['image'] * batch['in']['mask']
    if not self.depth_random_background or self.crop_random_background:
        batch['in']['depth'] = mask_normalized_depth(
            batch['in']['depth'], batch['in']['mask'])

    depth_in = None
    if self.generator_input_depth:
        depth_noise = self._depth_noise_dist.sample(
            batch['in']['depth'].size()).to(self.device)
        depth_in = (batch['in']['depth'] + depth_noise).clamp(-1, 1)

    data_process_time = self.mark_time()

    with autocast():
        # Evaluate generator.
        z_obj, z_extra = self._sculptor.encode(
            self._fuser,
            camera=batch['in']['camera'],
            color=batch['in']['image'],
            depth=depth_in,
            mask=batch['in']['mask'],
            data_parallel=self.data_parallel)
        fake_image, fake_depth, fake_mask, fake_mask_logits, fake_vox_depth = \
            self._run_photographer(z_obj, recon_camera, recon_mask)

        if 'blend_weights' in z_extra:
            z_weights = z_extra['blend_weights']
        else:
            z_weights = None

        # Train discriminator.
        if self._discriminator:
            d_real, d_fake_d, d_fake_g = self._run_discriminator(
                fake_image, fake_depth, fake_mask,
                recon_image, recon_depth, recon_mask)
            loss_d_real = multiscale_lsgan_loss(d_real, 1)
            loss_d_fake = multiscale_lsgan_loss(d_fake_d, 0)
            loss_d = loss_d_real + loss_d_fake
            loss_g_gan = multiscale_lsgan_loss(d_fake_g, 1)
            if train:
                loss_d.backward()
                if is_step:
                    self._optimizers['discriminator'].step()
            self.plotter.put_scalar('loss/discriminator/real', loss_d_real)
            self.plotter.put_scalar('loss/discriminator/fake', loss_d_fake)
            self.plotter.put_scalar('loss/discriminator/total', loss_d)
        else:
            loss_g_gan = torch.tensor(0.0, device=self.device)

        # Train generator.
        if self.predict_color:
            loss_g_color_recon = reduce_loss(
                self._g_color_recon_criterion(fake_image, recon_image))
        else:
            loss_g_color_recon = torch.tensor(0.0, device=self.device)

        if self.predict_depth or self.use_occlusion_depth:
            loss_g_depth_recon = reduce_loss(
                self._g_depth_recon_criterion(fake_depth, recon_depth))
        else:
            loss_g_depth_recon = torch.tensor(0.0, device=self.device)

        if self.predict_mask:
            if self.g_mask_recon_loss_type == 'binary_cross_entropy':
                y_mask = fake_mask_logits
            else:
                y_mask = fake_mask
            loss_g_mask_recon = reduce_loss(
                self._g_mask_recon_criterion(y_mask, recon_mask))
            loss_g_mask_beta = beta_prior_loss(
                fake_mask,
                alpha=self.g_mask_beta_loss_param,
                beta=self.g_mask_beta_loss_param)
        else:
            loss_g_mask_recon = torch.tensor(0.0, device=self.device)
            loss_g_mask_beta = torch.tensor(0.0, device=self.device)

        loss_g = (self.g_gan_loss_weight * loss_g_gan
                  + self.g_color_recon_loss_weight * loss_g_color_recon
                  + self.g_depth_recon_loss_weight * loss_g_depth_recon
                  + self.g_mask_recon_loss_weight * loss_g_mask_recon
                  + self.g_mask_beta_loss_weight * loss_g_mask_beta) / self.batch_groups

    if train:
        if self.kwargs.get('use_amp', False):
            self._scaler.scale(loss_g).backward()
        else:
            loss_g.backward()
        if is_step:
            if self.kwargs.get('use_amp', False):
                self._scaler.step(self._optimizers['generator'])
                self._scaler.update()
            else:
                self._optimizers['generator'].step()

    with torch.no_grad():
        if self.predict_depth:
            self.plotter.put_scalar('error/depth/l1',
                                    F.l1_loss(fake_depth, recon_depth))
            if self.reconstruct_input:
                self.plotter.put_scalar(
                    'error/depth/input_l1',
                    F.l1_loss(fake_depth[:, :self.num_input_views],
                              batch['in_gt']['depth']))
                self.plotter.put_scalar(
                    'error/depth/output_l1',
                    F.l1_loss(fake_depth[:, self.num_input_views:],
                              batch['out_gt']['depth']))
        if self.predict_mask:
            self.plotter.put_scalar(
                'error/mask/cross_entropy',
                F.binary_cross_entropy_with_logits(fake_mask_logits,
                                                   recon_mask))
            self.plotter.put_scalar('error/mask/l1',
                                    F.l1_loss(fake_mask, recon_mask))

    compute_time = self.mark_time()

    self.plotter.put_scalar('loss/generator/gan', loss_g_gan)
    self.plotter.put_scalar('loss/generator/recon/color', loss_g_color_recon)
    self.plotter.put_scalar('loss/generator/recon/depth', loss_g_depth_recon)
    self.plotter.put_scalar('loss/generator/recon/mask', loss_g_mask_recon)
    self.plotter.put_scalar('loss/generator/recon/mask_beta', loss_g_mask_beta)
    self.plotter.put_scalar('loss/generator/total', loss_g)
    self.plotter.put_scalar('params/input_noise_weight', self.input_noise_weight)
    if hasattr(self._g_depth_recon_criterion, 'k'):
        self.plotter.put_scalar('params/depth_loss_k',
                                self._g_depth_recon_criterion.k)

    self.plotter.put_scalar('time/data_process', data_process_time)
    self.plotter.put_scalar('time/compute', compute_time)

    plot_scalar_time = self.mark_time()
    self.plotter.put_scalar('time/plot/scalars', plot_scalar_time)

    if self.plotter.is_it_time_yet('histogram'):
        if self.predict_color:
            self.plotter.put_histogram('image_fake', fake_image)
            self.plotter.put_histogram('image_real', recon_image)
        if self.predict_mask:
            self.plotter.put_histogram('mask_fake', fake_mask)
        self.plotter.put_histogram('z_obj', z_obj)
        if z_weights is not None:
            self.plotter.put_histogram('z_weights', z_weights)

    plot_histogram_time = self.mark_time()
    self.plotter.put_scalar('time/plot/histogram', plot_histogram_time)

    if self.plotter.is_it_time_yet('show'):
        self.plotter.put_image(
            'inputs',
            viz.make_grid([
                gan_denormalize(batch['in']['image']),
                viz.colorize_depth(batch['in']['depth'])
                if self.generator_input_depth else None,
                viz.colorize_tensor(batch['in']['mask'])
                if self.generator_input_mask else None,
            ], row_size=4, stride=2, output_size=64))
        with torch.no_grad():
            self.plotter.put_image(
                'reconstruction',
                viz.make_grid([
                    gan_denormalize(recon_image),
                    gan_denormalize(fake_image)
                    if (fake_image is not None) else None,
                    viz.colorize_depth(recon_depth),
                    viz.colorize_depth(fake_depth)
                    if (fake_depth is not None) else None,
                    viz.colorize_tensor(
                        (recon_depth.cpu() - fake_depth.cpu()).abs())
                    if (fake_depth is not None) else None,
                    viz.colorize_tensor(recon_mask),
                    viz.colorize_tensor(fake_mask)
                    if (fake_mask is not None) else None,
                    viz.colorize_tensor(
                        (recon_mask.cpu() - fake_mask.cpu()).abs())
                    if (fake_mask is not None) else None,
                ], stride=8))

    plot_images_time = self.mark_time()
    self.plotter.put_scalar('time/plot/images', plot_images_time)
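# A minimal sketch of the multiscale_lsgan_loss helper used in run_iteration
# above (the real implementation is not shown). It assumes a multi-scale
# discriminator that returns one prediction map per scale; the LSGAN
# objective regresses every map toward the target label (1 = real, 0 = fake)
# with an MSE loss and sums over scales.
def multiscale_lsgan_loss_sketch(predictions, target_label):
    import torch
    import torch.nn.functional as F
    loss = 0.0
    for pred in predictions:  # one prediction tensor per discriminator scale
        target = torch.full_like(pred, float(target_label))
        loss = loss + F.mse_loss(pred, target)
    return loss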
def _optimize_camera(self, z_obj, target_obs, cameras, iters, ranking):
    optimizers = []
    schedulers = []
    param_cameras = [pu.parameterize_camera(camera, optimize_viewport=True)
                     for camera in cameras]
    for camera in param_cameras:
        parameters = [camera.log_quaternion, camera.translation, camera.viewport]
        optimizer = self.get_optimizer(self.optimizer, parameters,
                                       lr=self.learning_rate)
        optimizers.append(optimizer)
        schedulers.append(
            optim.lr_scheduler.ReduceLROnPlateau(
                optimizer,
                patience=self.lr_reduce_patience,
                threshold=self.lr_reduce_threshold,
                factor=self.lr_reduce_factor,
                verbose=self.verbose))

    pbar = utils.trange(iters)
    stat_history = {}
    converge_count = 0
    camera_history = []
    for step in pbar:
        for optimizer in optimizers:
            optimizer.zero_grad()

        cameras = Camera.cat(param_cameras)
        if self.loss_weights.get('latent', 0.0) > 0.0:
            with torch.no_grad():
                z_target_latent = self.model.compute_latent_code(target_obs,
                                                                 cameras)
        else:
            z_target_latent = None

        z_depth, z_mask, z_mask_logits, z_pred_latent = \
            self._render_observation(z_obj, cameras)

        optim_weights = copy.copy(self.loss_weights)
        optim_weights.update({k: v.get(step)
                              for k, v in self.loss_schedules.items()})
        loss_dict = self.loss_func(target_obs, z_depth, z_mask_logits, cameras,
                                   z_pred_latent=z_pred_latent,
                                   z_target_latent=z_target_latent)
        optim_loss = sum(weigh_losses(loss_dict, optim_weights).values())
        optim_loss.mean().backward()

        rank_loss = sum(weigh_losses(loss_dict, self.loss_weights).values())
        best_idx = torch.argmin(rank_loss)

        detached_cameras = pu.deparameterize_camera(cameras.uncrop()).clone()
        angle_dists = three.quaternion.angular_distance(
            detached_cameras.quaternion,
            target_obs.camera.quaternion).squeeze()
        translation_dists = torch.norm(
            detached_cameras.translation - target_obs.camera.translation,
            dim=1).squeeze()

        if self.return_camera_history:
            camera_history.append((rank_loss.detach().cpu(),
                                   detached_cameras.cpu()))

        # Save best cameras in ranking list.
        delta = self._track_best_items(ranking, step,
                                       items=detached_cameras.cpu(),
                                       loss=rank_loss)

        pbar.set_description(
            f"idx={best_idx}, loss={rank_loss[best_idx].item():.04f}"
            f", depth={loss_dict['depth'][best_idx].item():.04f}"
            f", ov_depth={loss_dict['ov_depth'][best_idx].item():.04f}"
            f", mask={loss_dict['mask'][best_idx].item():.04f}"
            f", iou={loss_dict['iou'][best_idx].item():.04f}"
            f", latent={loss_dict.get('latent', [0.0] * len(cameras))[best_idx]:.04f}"
            f", converge={converge_count}"
            f", angle={angle_dists[best_idx].item() / math.pi * 180:.02f}°"
            f", trans={translation_dists[best_idx].item():.04f}")

        if self.track_stats:
            self._record_stat_dict(stat_history, {
                **{f'{k}_loss': v.detach().cpu() for k, v in loss_dict.items()},
                **{f'{k}_weight': v for k, v in optim_weights.items()},
                'delta': delta,
                'converge_count': converge_count,
                'angle_dist': angle_dists.cpu(),
                'trans_dist': translation_dists.cpu(),
                'optim_loss': optim_loss.detach().cpu(),
                'rank_loss': rank_loss.detach().cpu(),
                # 'translation_grad': (cameras.translation.grad
                #                      if cameras.translation.grad is not None
                #                      else torch.zeros_like(cameras.translation)),
                # 'rotation_grad': (cameras.log_quaternion.grad
                #                   if cameras.log_quaternion.grad is not None
                #                   else torch.zeros_like(cameras.log_quaternion)),
                # 'viewport_grad': cameras.viewport.grad,
            })

        for i, (optimizer, scheduler) in enumerate(zip(optimizers, schedulers)):
            optimizer.step()
            scheduler.step(rank_loss[i])

        if delta < self.converge_threshold:
            converge_count += 1
        elif delta > self.converge_threshold:
            converge_count = 0
        if converge_count >= self.converge_patience:
            logger.info("convergence threshold reached",
                        step=step, delta=delta, count=converge_count)
            pbar.close()
            break

    return stat_history, camera_history
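# A minimal sketch of the _track_best_items bookkeeping shared by the
# estimators above (the real implementation is not shown, and argument names
# vary across the call sites). `ranking` is maintained as a list of
# (camera, loss, step) tuples sorted by loss; the return value is how much
# the best loss improved this step, which callers use both to log history
# (delta > 0) and to detect convergence (delta below a threshold). The list
# length k is an assumption.
def track_best_items_sketch(ranking, step, cameras, losses, k=8):
    prev_best = ranking[0][1] if ranking else float('inf')
    for i in range(len(losses)):
        ranking.append((cameras[i], losses[i].item(), step))
    ranking.sort(key=lambda item: item[1])
    del ranking[k:]  # keep only the k best candidates seen so far
    return max(prev_best - ranking[0][1], 0.0)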