def depth_to_L(pr_depth_map, gt_mask):
    ''' Min-max normalises a depth map for visualisation.

    The minimum and maximum are taken over ``pr_depth_map[gt_mask]`` only, so
    the masked values end up spanning [0, 1]. ``background_setting`` is called
    with the maximum as its third argument (presumably the fill value for the
    pixels outside the mask -- confirm against its definition).

    This function is not in-place: the fill is applied to a clone, so the
    caller's tensor is left untouched.

    Args:
        pr_depth_map (tensor): depth map to normalise
        gt_mask (tensor): mask used to index the valid pixels of the depth map

    Returns:
        tensor: a new tensor holding the normalised depth map
    '''
    masked_depth = pr_depth_map[gt_mask]
    depth_max = torch.max(masked_depth)
    depth_min = torch.min(masked_depth)

    # background_setting's return value is never used by any caller, so it has
    # to write into its first argument; hand it a copy instead of the input.
    depth_map = pr_depth_map.clone()
    background_setting(depth_map, gt_mask, depth_max)

    # A constant depth map has zero range; dividing would yield 0/0 = NaN.
    depth_range = depth_max - depth_min
    if depth_range == 0:
        depth_range = torch.ones_like(depth_range)
    return (depth_map - depth_min) / depth_range
def compute_loss(self, data):
    ''' Evaluates the training loss for one batch.

    The result is the batch mean of the KL term between the inferred latent
    distribution and the model prior ``p0_z``, plus the post-processed
    occupancy loss summed over its last axis and averaged.

    Args:
        data (dict): data dictionary

    Returns:
        tensor: scalar loss
    '''
    dev = self.device
    points = data.get('points').to(dev)
    n_batch = points.size(0)
    occ_gt = data.get('points.occ').to(dev)
    images = data.get('inputs').to(dev)
    mask = data.get('inputs.mask').to(dev).byte()

    # Depth prediction; no gradient is recorded when training_detach is set
    if self.training_detach:
        with torch.no_grad():
            depth_maps = self.model.predict_depth_map(images)
    else:
        depth_maps = self.model.predict_depth_map(images)
    background_setting(depth_maps, mask)

    # Optionally blend predicted and ground truth depth with a random
    # per-sample weight
    if self.depth_map_mix:
        depth_gt = data.get('inputs.depth').to(dev)
        background_setting(depth_gt, mask)
        alpha = torch.rand(n_batch, 1, 1, 1).to(dev)
        depth_maps = depth_maps * alpha + depth_gt * (1.0 - alpha)

    # Encode the depth maps and sample a latent code
    code = self.model.encode(depth_maps)
    posterior = self.model.infer_z(points, occ_gt, code)
    latent = posterior.rsample()

    # KL-divergence
    kl_term = dist.kl_divergence(posterior, self.model.p0_z).sum(dim=-1).mean()

    # Occupancy prediction at the query points
    decoded = self.model.decode(points, latent, code)
    logits = decoded.logits
    probs = decoded.probs

    # Raw occupancy loss followed by the configured loss strategies
    occ_loss = get_occ_loss(logits, occ_gt, self.loss_type)
    occ_loss = occ_loss_postprocess(
        occ_loss, occ_gt, probs,
        self.loss_tolerance_episolon, self.sign_lambda,
        self.threshold, self.surface_loss_weight)

    return kl_term + occ_loss.sum(-1).mean()
def forward(self, data, device):
    ''' Runs the feature extractor and prediction head on the depth input.

    The ground truth depth maps are passed through ``background_setting``
    together with the mask before being fed to ``self.features``. When
    ``self.with_img`` is set, the image is supplied alongside the depth in a
    dict with keys ``'img'`` and ``'depth'``.

    Args:
        data (dict): data dictionary
        device (device): device the tensors are moved to

    Returns:
        output of ``self.pred_fc`` applied to the extracted features
    '''
    depth = data.get('inputs.depth').to(device)
    mask = data.get('inputs.mask').to(device).byte()
    background_setting(depth, mask)

    if self.with_img:
        image = data.get('inputs').to(device)
        features_in = {'img': image, 'depth': depth}
    else:
        features_in = depth

    return self.pred_fc(self.features(features_in))
def visualize(self, data):
    ''' Writes visualisations for one batch into ``self.vis_dir``.

    For every sample ``i`` three files are produced: the input image
    (``%03d_in.png``), the thresholded occupancy grid (``%03d.png``) and the
    normalised predicted depth map (``%03d_pr_depth.png``).

    Args:
        data (dict): data dictionary
    '''
    dev = self.device
    n_batch = data['points'].size(0)
    images = data.get('inputs').to(dev)
    mask = data.get('inputs.mask').to(dev).byte()

    # Regular 32^3 query grid over [-0.5, 0.5]^3, shared by all samples
    grid_shape = (32, 32, 32)
    grid = make_3d_grid([-0.5] * 3, [0.5] * 3, grid_shape).to(dev)
    grid = grid.expand(n_batch, *grid.size())

    with torch.no_grad():
        depth_maps = self.model.predict_depth_map(images)
        background_setting(depth_maps, mask)
        prediction = self.model.forward_halfway(
            grid, depth_maps, sample=self.eval_sample)

    occupancy = prediction.probs.view(n_batch, *grid_shape)
    voxels = (occupancy >= self.threshold).cpu().numpy()

    for idx in trange(n_batch):
        # Input image
        image_path = os.path.join(self.vis_dir, '%03d_in.png' % idx)
        vis.visualize_data(images[idx].cpu(), 'img', image_path)

        # Predicted occupancy as voxels
        voxel_path = os.path.join(self.vis_dir, '%03d.png' % idx)
        vis.visualize_voxels(voxels[idx], voxel_path)

        # Predicted depth map, rescaled for display
        depth_path = os.path.join(self.vis_dir, '%03d_pr_depth.png' % idx)
        depth_image = depth_to_L(depth_maps[idx].cpu(), mask[idx].cpu())
        vis.visualize_data(depth_image, 'img', depth_path)
def compute_loss(self, data):
    ''' Computes the loss.

    The loss is the sum of
      * the feature-transform regulariser (only when the model returns a
        tuple and its encoder is a PointNet(Res)Encoder),
      * the reconstruction term: chamfer distance if ``self.loss_type`` is
        ``'cd'``, earth mover distance divided by the point count otherwise,
      * the optional view penalties (mask flow, and depth test when
        ``self.view_penalty == 'mask_flow_and_depth'``).

    Args:
        data (dict): data dictionary

    Returns:
        tensor: scalar loss

    Raises:
        NotImplementedError: if a view penalty is requested for a
            ``gt_pointcloud_transfer`` other than ``'world_scale_model'``,
            ``'view_scale_model'`` or ``'view'``
    '''
    device = self.device
    encoder_inputs, raw_data = compose_inputs(
        data,
        mode='train',
        device=self.device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer,
    )

    # The world matrix is only fetched when the encoder or the ground truth
    # transfer needs it; reuse the one compose_inputs already loaded if any.
    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)
    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)

    if self.model.encoder_world_mat is not None:
        out = self.model(encoder_inputs, world_mat=world_mat)
    else:
        out = self.model(encoder_inputs)

    loss = 0
    if isinstance(out, tuple):
        out, trans_feat = out
        if isinstance(self.model.encoder,
                      (PointNetEncoder, PointNetResEncoder)):
            loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

    if self.loss_type == 'cd':
        # chamfer distance loss; accumulate so the regulariser above is kept,
        # exactly as the EMD branch does
        dist1, dist2 = cd.chamfer_distance(out, gt_pc)
        loss = loss + (dist1.mean(1) + dist2.mean(1)).mean() / 2.
    else:
        out_pts_count = out.size(1)
        loss = loss + (emd.earth_mover_distance(
            out, gt_pc, transpose=False) / out_pts_count).mean()

    # view penalty loss
    if self.view_penalty:
        gt_mask_flow = data.get('inputs.mask_flow').to(
            device)  # B * 1 * H * W
        if world_mat is None:
            world_mat = get_world_mat(data, device=device)
        camera_mat = get_camera_mat(data, device=device)

        # projection use world mat & camera mat
        # NOTE(review): the two view branches assume world_mat is B x 3 x 4
        # ([R | t]) and out is B x n_pts x 3, so the translation column is
        # transposed to B x 1 x 3 before being added -- confirm.
        if self.gt_pointcloud_transfer == 'world_scale_model':
            out_pts = transform_points(out, world_mat)
        elif self.gt_pointcloud_transfer == 'view_scale_model':
            t = world_mat[:, :, 3:]
            out_pts = out + t.transpose(1, 2)
        elif self.gt_pointcloud_transfer == 'view':
            t = world_mat[:, :, 3:]
            out_pts = out * t[:, 2:, :]
            out_pts = out_pts + t.transpose(1, 2)
        else:
            raise NotImplementedError

        out_pts_img = project_to_camera(out_pts, camera_mat)
        out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

        out_mask_flow = F.grid_sample(gt_mask_flow,
                                      out_pts_img)  # B * 1 * 1 * n_pts
        # Penalise projected points whose sampled mask-flow value falls below
        # 1 - mask_flow_eps
        loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                inplace=True).mean()
        loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

        if self.view_penalty == 'mask_flow_and_depth':
            # depth test loss
            t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
            gt_mask = data.get('inputs.mask').byte().to(device)
            depth_pred = data.get('inputs.depth_pred').to(
                device) * t_scale  # absolute depth from view
            background_setting(depth_pred, gt_mask)
            depth_z = out_pts[:, :, 2:].transpose(1, 2)
            corresponding_z = F.grid_sample(
                depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
            corresponding_z = corresponding_z.squeeze(1)

            # Penalise points lying more than depth_test_eps behind the
            # depth sampled at their projection
            loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                     corresponding_z,
                                     inplace=True).mean()
            loss = loss + self.loss_depth_test_ratio * loss_depth_test

    return loss
def eval_step(self, data):
    ''' Performs an evaluation step.

    Args:
        data (dict): data dictionary

    Returns:
        dict: evaluation metrics. Always contains ``'chamfer'`` or ``'emd'``
            (depending on ``self.loss_type``); contains ``'loss_mask_flow'``
            and possibly ``'loss_depth_test'`` when view penalties are
            enabled; for batch size 1 it starts from the dictionary returned
            by ``self.mesh_evaluator.eval_pointcloud``.

    Raises:
        NotImplementedError: if a view penalty is requested for a
            ``gt_pointcloud_transfer`` other than ``'world_scale_model'``,
            ``'view_scale_model'`` or ``'view'``
    '''
    self.model.eval()

    device = self.device
    encoder_inputs, raw_data = compose_inputs(
        data,
        mode='train',
        device=self.device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer)

    # Reuse the world matrix loaded by compose_inputs when available
    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)
    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)
    batch_size = gt_pc.size(0)

    with torch.no_grad():
        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)

        if isinstance(out, tuple):
            # the transform feature is not needed for evaluation
            out, _ = out

        eval_dict = {}
        if batch_size == 1:
            pointcloud_hat = out.cpu().squeeze(0).numpy()
            pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()
            eval_dict = self.mesh_evaluator.eval_pointcloud(
                pointcloud_hat, pointcloud_gt)

        # per-sample reconstruction loss (reduced over the batch further down)
        if self.loss_type == 'cd':
            dist1, dist2 = cd.chamfer_distance(out, gt_pc)
            loss = (dist1.mean(1) + dist2.mean(1)) / 2.
        else:
            loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

        if self.gt_pointcloud_transfer in ('world_scale_model',
                                           'view_scale_model', 'view'):
            # Give the per-sample factors the same rank as `loss` so they are
            # applied sample by sample. A (B, 1, 1) factor against a (B,)
            # loss would broadcast to (B, 1, B) and mix samples.
            per_sample = (batch_size,) + (1,) * (loss.dim() - 1)
            pointcloud_scale = data.get('pointcloud.scale').to(
                device).reshape(per_sample)
            loss = loss / (pointcloud_scale**2)
            if self.gt_pointcloud_transfer == 'view':
                if world_mat is None:
                    world_mat = get_world_mat(data, device=device)
                t_scale = world_mat[:, 2, 3].reshape(per_sample)
                loss = loss * (t_scale**2)

        if self.loss_type == 'cd':
            loss = loss.mean()
            eval_dict['chamfer'] = loss.item()
        else:
            out_pts_count = out.size(1)
            loss = (loss / out_pts_count).mean()
            eval_dict['emd'] = loss.item()

        # view penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # projection use world mat & camera mat
            # NOTE(review): the two view branches assume world_mat is
            # B x 3 x 4 ([R | t]) and out is B x n_pts x 3, so the translation
            # column is transposed to B x 1 x 3 before being added -- confirm.
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                t = world_mat[:, :, 3:]
                out_pts = out + t.transpose(1, 2)
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:]
                out_pts = out * t[:, 2:, :]
                out_pts = out_pts + t.transpose(1, 2)
            else:
                raise NotImplementedError

            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2

            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                    inplace=True).mean()
            loss = self.loss_mask_flow_ratio * loss_mask_flow
            eval_dict['loss_mask_flow'] = loss.item()

            if self.view_penalty == 'mask_flow_and_depth':
                # depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale

                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)

                # Penalise points lying more than depth_test_eps behind the
                # depth sampled at their projection
                loss_depth_test = F.relu(depth_z - self.depth_test_eps -
                                         corresponding_z,
                                         inplace=True).mean()
                loss = self.loss_depth_test_ratio * loss_depth_test
                eval_dict['loss_depth_test'] = loss.item()

    return eval_dict
def _compose_depth_pred_inputs(data, mode, device, use_gt_depth_map,
                               depth_map_mix, with_img, local):
    ''' Builds encoder inputs from ground truth or predicted depth maps. '''
    raw = {}
    mask = data.get('inputs.mask').to(device).byte()
    raw['mask'] = mask
    n_batch = mask.size(0)

    if use_gt_depth_map:
        depth = data.get('inputs.depth').to(device)
        background_setting(depth, mask)
        raw['depth'] = depth
    else:
        depth = data.get('inputs.depth_pred').to(device)
        background_setting(depth, mask)
        raw['depth_pred'] = depth
        if depth_map_mix and mode == 'train':
            # Blend prediction and ground truth with a random per-sample
            # weight; raw['depth_pred'] keeps the unmixed prediction.
            depth_gt = data.get('inputs.depth').to(device)
            background_setting(depth_gt, mask)
            raw['depth'] = depth_gt
            alpha = torch.rand(n_batch, 1, 1, 1).to(device)
            depth = depth * alpha + depth_gt * (1.0 - alpha)

    encoder_inputs = depth

    if with_img:
        image = data.get('inputs').to(device)
        raw[None] = image
        encoder_inputs = {'img': image, 'depth': encoder_inputs}

    if local:
        camera_args = get_camera_args(data, 'points.loc', 'points.scale',
                                      device=device)
        world_mat = camera_args['Rt']
        camera_mat = camera_args['K']
        encoder_inputs = {
            None: encoder_inputs,
            'world_mat': world_mat,
            'camera_mat': camera_mat,
        }
        raw['world_mat'] = world_mat
        raw['camera_mat'] = camera_mat

    return encoder_inputs, raw


def _compose_depth_pointcloud_inputs(data, device, depth_pointcloud_transfer,
                                     local):
    ''' Builds encoder inputs from a depth point cloud. '''
    raw = {}
    points = data.get('inputs.depth_pointcloud').to(device)

    if depth_pointcloud_transfer is not None:
        to_world = depth_pointcloud_transfer in ('world', 'world_scale_model')
        to_view = depth_pointcloud_transfer in ('view', 'view_scale_model')
        if not (to_world or to_view):
            raise NotImplementedError

        # Both families first swap the first two coordinates
        points = points[:, :, [1, 0, 2]]

        if to_world:
            world_mat = get_world_mat(data, transpose=None, device=device)
            raw['world_mat'] = world_mat
            rotation = world_mat[:, :, :3]
            # R's inverse is R^T
            # (equivalently: transform_points_back(points, rotation))
            points = transform_points(points, rotation.transpose(1, 2))
            if depth_pointcloud_transfer == 'world_scale_model':
                points = points * world_mat[:, 2:, 3:]
        elif depth_pointcloud_transfer == 'view_scale_model':
            world_mat = get_world_mat(data, transpose=None, device=device)
            raw['world_mat'] = world_mat
            points = points * world_mat[:, 2:, 3:]

    raw['depth_pointcloud'] = points
    encoder_inputs = {None: points} if local else points
    return encoder_inputs, raw


def compose_inputs(data, mode='train', device=None, input_type='depth_pred',
                   use_gt_depth_map=False, depth_map_mix=False,
                   with_img=False, depth_pointcloud_transfer=None,
                   local=False):
    ''' Assembles the encoder inputs for a batch.

    Args:
        data (dict): data dictionary
        mode (str): one of 'train', 'val', 'test'; depth map mixing is only
            applied in 'train' mode
        device (device): device the tensors are moved to
        input_type (str): 'depth_pred' or 'depth_pointcloud'
        use_gt_depth_map (bool): feed ground truth instead of predicted depth
        depth_map_mix (bool): randomly blend predicted and ground truth depth
        with_img (bool): additionally pass the image to the encoder
        depth_pointcloud_transfer (str): None, 'world', 'world_scale_model',
            'view' or 'view_scale_model'
        local (bool): wrap the inputs in a dict keyed by None (plus camera
            matrices for 'depth_pred')

    Returns:
        tuple: (encoder_inputs, raw_data), where raw_data is a dict holding
            the intermediate tensors that were loaded

    Raises:
        NotImplementedError: for an unknown input_type or
            depth_pointcloud_transfer
    '''
    assert mode in ('train', 'val', 'test')

    if input_type == 'depth_pred':
        return _compose_depth_pred_inputs(
            data, mode, device, use_gt_depth_map, depth_map_mix, with_img,
            local)
    if input_type == 'depth_pointcloud':
        return _compose_depth_pointcloud_inputs(
            data, device, depth_pointcloud_transfer, local)
    raise NotImplementedError
def generate_mesh(self, data, return_stats=True):
    ''' Produces a mesh for the given data.

    The encoder inputs come from ``compose_inputs`` for the 'depth_pred' and
    'depth_pointcloud' input types; otherwise a depth map is predicted from
    the (optionally preprocessed) image. Timings of the individual stages are
    collected in a stats dictionary.

    Args:
        data (tensor): data tensor
        return_stats (bool): whether stats should be returned

    Returns:
        the mesh, or a tuple (mesh, stats_dict) if return_stats is set
    '''
    self.model.eval()
    dev = self.device
    timings = {}

    if self.input_type in ('depth_pred', 'depth_pointcloud'):
        encoder_inputs, _ = compose_inputs(
            data,
            mode='test',
            device=self.device,
            input_type=self.input_type,
            use_gt_depth_map=self.use_gt_depth_map,
            depth_map_mix=False,
            with_img=self.with_img,
            depth_pointcloud_transfer=self.depth_pointcloud_transfer,
            local=self.local)
    else:
        images = data.get('inputs').to(dev)
        mask = data.get('inputs.mask').to(dev).byte()

        # Optional preprocessing of the input image
        if self.preprocessor is not None:
            tic = time.time()
            with torch.no_grad():
                images = self.preprocessor(images)
            timings['time (preprocess)'] = time.time() - tic

        # Depth prediction
        tic = time.time()
        with torch.no_grad():
            depth = self.model.predict_depth_map(images)
        timings['time (predict depth map)'] = time.time() - tic
        background_setting(depth, mask)
        encoder_inputs = depth

    # Encode inputs
    tic = time.time()
    with torch.no_grad():
        if self.local:
            c = self.model.encoder.forward_local_first_step(encoder_inputs)
        else:
            c = self.model.encode(encoder_inputs)
    timings['time (encode)'] = time.time() - tic

    z = self.model.get_z_from_prior((1, ), sample=self.sample).to(dev)

    # The local variant also needs the encoder inputs during extraction
    if self.local:
        mesh = self.generate_from_latent(z, c, data=encoder_inputs,
                                         stats_dict=timings)
    else:
        mesh = self.generate_from_latent(z, c, stats_dict=timings)

    if not return_stats:
        return mesh
    return mesh, timings