def visualize(self, data):
    ''' Renders ground-truth, predicted and input point clouds for a batch.

    The model is switched to eval mode and run without gradients. For every
    sample index ``i`` three images are written into ``self.vis_dir``:
    ``%03d_gt_pc.png``, ``%03d_pr_pc.png`` and ``%03d_input_half_pc.png``.

    Args:
        data (dict): data dictionary
    '''
    device = self.device
    encoder_inputs, raw_data = compose_inputs(
        data,
        mode='train',
        device=device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer,
    )

    # A world matrix is required either by the encoder itself or to move the
    # ground truth into the view frame; reuse the one compose_inputs produced
    # when available.
    world_mat = None
    needs_world_mat = (self.model.encoder_world_mat is not None
                       or self.gt_pointcloud_transfer in ('view', 'view_scale_model'))
    if needs_world_mat:
        world_mat = (raw_data['world_mat'] if 'world_mat' in raw_data
                     else get_world_mat(data, device=device))

    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)
    n_samples = gt_pc.size(0)

    self.model.eval()
    with torch.no_grad():
        model_kwargs = ({} if self.model.encoder_world_mat is None
                        else {'world_mat': world_mat})
        prediction = self.model(encoder_inputs, **model_kwargs)
    if isinstance(prediction, tuple):
        # The model may also return auxiliary outputs; only the cloud is drawn.
        prediction, _ = prediction

    renders = (
        (gt_pc, '%03d_gt_pc.png'),
        (prediction, '%03d_pr_pc.png'),
        (encoder_inputs, '%03d_input_half_pc.png'),
    )
    for idx in trange(n_samples):
        for clouds, name_pattern in renders:
            vis.visualize_pointcloud(
                clouds[idx].cpu(),
                out_file=os.path.join(self.vis_dir, name_pattern % idx))
def compute_loss(self, data):
    ''' Computes the training loss for one batch.

    The returned loss is the sum of:
      * ``0.001 * feature_transform_reguliarzer(trans_feat)`` when the model
        returns ``(out, trans_feat)`` and the encoder is a PointNetEncoder or
        PointNetResEncoder;
      * the reconstruction loss: chamfer distance when
        ``self.loss_type == 'cd'``, otherwise earth mover's distance divided
        by the number of predicted points;
      * when ``self.view_penalty`` is truthy, a mask-flow penalty weighted by
        ``self.loss_mask_flow_ratio``, plus, for ``'mask_flow_and_depth'``,
        a depth-test penalty weighted by ``self.loss_depth_test_ratio``.

    Args:
        data (dict): data dictionary

    Returns:
        torch.Tensor: scalar loss

    Raises:
        NotImplementedError: view penalty requested with an unsupported
            ``self.gt_pointcloud_transfer``.
    '''
    device = self.device
    encoder_inputs, raw_data = compose_inputs(
        data,
        mode='train',
        device=device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer,
    )

    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)
    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)

    if self.model.encoder_world_mat is not None:
        out = self.model(encoder_inputs, world_mat=world_mat)
    else:
        out = self.model(encoder_inputs)

    loss = 0
    if isinstance(out, tuple):
        out, trans_feat = out
        if isinstance(self.model.encoder,
                      (PointNetEncoder, PointNetResEncoder)):
            loss = loss + 0.001 * feature_transform_reguliarzer(trans_feat)

    # Reconstruction loss. Both branches accumulate so the regularizer above
    # is not discarded (the chamfer branch previously overwrote it).
    if self.loss_type == 'cd':
        dist1, dist2 = cd.chamfer_distance(out, gt_pc)
        loss = loss + (dist1.mean(1) + dist2.mean(1)).mean() / 2.
    else:
        out_pts_count = out.size(1)
        loss = loss + (emd.earth_mover_distance(
            out, gt_pc, transpose=False) / out_pts_count).mean()

    # view penalty loss
    if self.view_penalty:
        gt_mask_flow = data.get('inputs.mask_flow').to(
            device)  # B * 1 * H * W
        if world_mat is None:
            world_mat = get_world_mat(data, device=device)
        camera_mat = get_camera_mat(data, device=device)

        # Bring the prediction into camera space before projecting.
        # `out` is B * n_pts * 3 (points on the last axis), while the
        # translation column world_mat[:, :, 3:] is B * 3 * 1, so it is
        # transposed to B * 1 * 3 to broadcast over the points.
        if self.gt_pointcloud_transfer == 'world_scale_model':
            out_pts = transform_points(out, world_mat)
        elif self.gt_pointcloud_transfer == 'view_scale_model':
            t = world_mat[:, :, 3:].transpose(1, 2)  # B * 1 * 3
            out_pts = out + t
        elif self.gt_pointcloud_transfer == 'view':
            t = world_mat[:, :, 3:].transpose(1, 2)  # B * 1 * 3
            t_z = world_mat[:, 2:, 3:]  # B * 1 * 1
            out_pts = out * t_z + t
        else:
            raise NotImplementedError

        out_pts_img = project_to_camera(out_pts, camera_mat)
        out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2
        out_mask_flow = F.grid_sample(gt_mask_flow,
                                      out_pts_img)  # B * 1 * 1 * n_pts
        loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                inplace=True).mean()
        loss = loss + self.loss_mask_flow_ratio * loss_mask_flow

        if self.view_penalty == 'mask_flow_and_depth':
            # depth test loss
            t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
            gt_mask = data.get('inputs.mask').byte().to(device)
            depth_pred = data.get('inputs.depth_pred').to(
                device) * t_scale  # absolute depth from view
            background_setting(depth_pred, gt_mask)
            depth_z = out_pts[:, :, 2:].transpose(1, 2)  # B * 1 * n_pts
            corresponding_z = F.grid_sample(
                depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
            corresponding_z = corresponding_z.squeeze(1)  # B * 1 * n_pts
            loss_depth_test = F.relu(
                depth_z - self.depth_test_eps - corresponding_z,
                inplace=True).mean()
            loss = loss + self.loss_depth_test_ratio * loss_depth_test

    return loss
def eval_step(self, data):
    ''' Performs an evaluation step.

    Args:
        data (dict): data dictionary

    Returns:
        dict: evaluation metrics. When the batch size is 1 it starts from
        the dict returned by ``self.mesh_evaluator.eval_pointcloud``. It then
        gets ``'chamfer'`` (``self.loss_type == 'cd'``) or ``'emd'``, and,
        when ``self.view_penalty`` is truthy, ``'loss_mask_flow'`` and
        possibly ``'loss_depth_test'`` (both already multiplied by their
        ratio).

    Raises:
        NotImplementedError: view penalty requested with an unsupported
            ``self.gt_pointcloud_transfer``.
    '''
    self.model.eval()
    device = self.device
    encoder_inputs, raw_data = compose_inputs(
        data,
        mode='train',
        device=device,
        input_type=self.input_type,
        depth_pointcloud_transfer=self.depth_pointcloud_transfer)

    world_mat = None
    if (self.model.encoder_world_mat is not None) \
            or self.gt_pointcloud_transfer in ('view', 'view_scale_model'):
        if 'world_mat' in raw_data:
            world_mat = raw_data['world_mat']
        else:
            world_mat = get_world_mat(data, device=device)
    gt_pc = compose_pointcloud(data, device, self.gt_pointcloud_transfer,
                               world_mat=world_mat)
    batch_size = gt_pc.size(0)

    with torch.no_grad():
        if self.model.encoder_world_mat is not None:
            out = self.model(encoder_inputs, world_mat=world_mat)
        else:
            out = self.model(encoder_inputs)
        if isinstance(out, tuple):
            # auxiliary output (feature transform) is not needed for eval
            out, _ = out

        eval_dict = {}
        if batch_size == 1:
            pointcloud_hat = out.cpu().squeeze(0).numpy()
            pointcloud_gt = gt_pc.cpu().squeeze(0).numpy()
            eval_dict = self.mesh_evaluator.eval_pointcloud(
                pointcloud_hat, pointcloud_gt)

        # Per-sample reconstruction loss; assumed to be shape (B,) for both
        # chamfer (after .mean(1)) and EMD -- TODO confirm the emd extension.
        if self.loss_type == 'cd':
            dist1, dist2 = cd.chamfer_distance(out, gt_pc)
            loss = (dist1.mean(1) + dist2.mean(1)) / 2.
        else:
            loss = emd.earth_mover_distance(out, gt_pc, transpose=False)

        if self.gt_pointcloud_transfer in ('world_scale_model',
                                           'view_scale_model', 'view'):
            # Undo the ground-truth scaling. The scale factors are kept 1-D
            # (shape (B,)) so they rescale each sample's loss element-wise;
            # a (B, 1, 1) shape would broadcast the (B,) loss to (B, 1, B)
            # and mix samples in the mean below whenever B > 1.
            pointcloud_scale = data.get('pointcloud.scale').to(
                device).view(batch_size)
            loss = loss / (pointcloud_scale ** 2)
            if self.gt_pointcloud_transfer == 'view':
                if world_mat is None:
                    world_mat = get_world_mat(data, device=device)
                t_scale = world_mat[:, 2, 3]  # (B,)
                loss = loss * (t_scale ** 2)

        if self.loss_type == 'cd':
            loss = loss.mean()
            eval_dict['chamfer'] = loss.item()
        else:
            out_pts_count = out.size(1)
            loss = (loss / out_pts_count).mean()
            eval_dict['emd'] = loss.item()

        # view penalty loss
        if self.view_penalty:
            gt_mask_flow = data.get('inputs.mask_flow').to(
                device)  # B * 1 * H * W
            if world_mat is None:
                world_mat = get_world_mat(data, device=device)
            camera_mat = get_camera_mat(data, device=device)

            # Bring the prediction into camera space before projecting.
            # `out` is B * n_pts * 3, the translation column is B * 3 * 1,
            # so it is transposed to B * 1 * 3 to broadcast over the points.
            if self.gt_pointcloud_transfer == 'world_scale_model':
                out_pts = transform_points(out, world_mat)
            elif self.gt_pointcloud_transfer == 'view_scale_model':
                t = world_mat[:, :, 3:].transpose(1, 2)  # B * 1 * 3
                out_pts = out + t
            elif self.gt_pointcloud_transfer == 'view':
                t = world_mat[:, :, 3:].transpose(1, 2)  # B * 1 * 3
                t_z = world_mat[:, 2:, 3:]  # B * 1 * 1
                out_pts = out * t_z + t
            else:
                raise NotImplementedError

            out_pts_img = project_to_camera(out_pts, camera_mat)
            out_pts_img = out_pts_img.unsqueeze(1)  # B * 1 * n_pts * 2
            out_mask_flow = F.grid_sample(gt_mask_flow,
                                          out_pts_img)  # B * 1 * 1 * n_pts
            loss_mask_flow = F.relu(1. - self.mask_flow_eps - out_mask_flow,
                                    inplace=True).mean()
            loss = self.loss_mask_flow_ratio * loss_mask_flow
            eval_dict['loss_mask_flow'] = loss.item()

            if self.view_penalty == 'mask_flow_and_depth':
                # depth test loss
                t_scale = world_mat[:, 2, 3].view(world_mat.size(0), 1, 1, 1)
                gt_mask = data.get('inputs.mask').byte().to(device)
                depth_pred = data.get('inputs.depth_pred').to(
                    device) * t_scale  # absolute depth from view
                background_setting(depth_pred, gt_mask)
                depth_z = out_pts[:, :, 2:].transpose(1, 2)  # B * 1 * n_pts
                corresponding_z = F.grid_sample(
                    depth_pred, out_pts_img)  # B * 1 * 1 * n_pts
                corresponding_z = corresponding_z.squeeze(1)  # B * 1 * n_pts
                loss_depth_test = F.relu(
                    depth_z - self.depth_test_eps - corresponding_z,
                    inplace=True).mean()
                loss = self.loss_depth_test_ratio * loss_depth_test
                eval_dict['loss_depth_test'] = loss.item()

    return eval_dict
def compose_inputs(data, mode='train', device=None, input_type='depth_pred',
                   use_gt_depth_map=False, depth_map_mix=False, with_img=False,
                   depth_pointcloud_transfer=None, local=False):
    ''' Assembles the encoder inputs for one batch.

    Args:
        data (dict): batch dictionary; which keys are read depends on
            ``input_type``
        mode (str): one of 'train', 'val', 'test'; depth-map mixing is only
            applied in 'train'
        device: device the tensors are moved to
        input_type (str): 'depth_pred' or 'depth_pointcloud'
        use_gt_depth_map (bool): ('depth_pred' only) use 'inputs.depth'
            instead of 'inputs.depth_pred'
        depth_map_mix (bool): ('depth_pred' only) blend the predicted and
            ground-truth depth maps with a random per-sample weight
        with_img (bool): ('depth_pred' only) also pass 'inputs' as the image
        depth_pointcloud_transfer (str or None): ('depth_pointcloud' only)
            one of 'world', 'world_scale_model', 'view', 'view_scale_model'
        local (bool): wrap the encoder inputs in a dict under the key None

    Returns:
        tuple: ``(encoder_inputs, raw_data)``; raw_data maps names to the
        intermediate tensors that were built.

    Raises:
        NotImplementedError: unknown ``input_type`` or
            ``depth_pointcloud_transfer``.
    '''
    assert mode in ('train', 'val', 'test')
    if input_type == 'depth_pred':
        return _compose_depth_map_inputs(
            data, mode, device, use_gt_depth_map, depth_map_mix, with_img,
            local)
    if input_type == 'depth_pointcloud':
        return _compose_depth_pointcloud_inputs(
            data, device, depth_pointcloud_transfer, local)
    raise NotImplementedError


def _compose_depth_map_inputs(data, mode, device, use_gt_depth_map,
                              depth_map_mix, with_img, local):
    # 'depth_pred' branch of compose_inputs: masked depth maps (+ image/camera).
    raw = {}
    mask = data.get('inputs.mask').to(device).byte()
    raw['mask'] = mask

    if use_gt_depth_map:
        depth = data.get('inputs.depth').to(device)
        background_setting(depth, mask)
        raw['depth'] = depth
    else:
        depth = data.get('inputs.depth_pred').to(device)
        background_setting(depth, mask)
        raw['depth_pred'] = depth
        if depth_map_mix and mode == 'train':
            gt_depth = data.get('inputs.depth').to(device)
            background_setting(gt_depth, mask)
            raw['depth'] = gt_depth
            # one blending weight per sample
            alpha = torch.rand(mask.size(0), 1, 1, 1).to(device)
            depth = depth * alpha + gt_depth * (1.0 - alpha)
    inputs = depth

    if with_img:
        img = data.get('inputs').to(device)
        raw[None] = img
        inputs = {'img': img, 'depth': inputs}

    if local:
        camera_args = get_camera_args(
            data, 'points.loc', 'points.scale', device=device)
        rt_mat, k_mat = camera_args['Rt'], camera_args['K']
        inputs = {None: inputs, 'world_mat': rt_mat, 'camera_mat': k_mat}
        raw['world_mat'] = rt_mat
        raw['camera_mat'] = k_mat
    return inputs, raw


def _compose_depth_pointcloud_inputs(data, device, transfer, local):
    # 'depth_pointcloud' branch of compose_inputs: optional frame conversion.
    raw = {}
    points = data.get('inputs.depth_pointcloud').to(device)

    if transfer is not None:
        if transfer in ('world', 'world_scale_model'):
            to_world = True
        elif transfer in ('view', 'view_scale_model'):
            to_world = False
        else:
            raise NotImplementedError

        # swap the first two coordinates of every point (last axis)
        points = points[:, :, [1, 0, 2]]

        world_mat = None
        if to_world or transfer == 'view_scale_model':
            world_mat = get_world_mat(data, transpose=None, device=device)
            raw['world_mat'] = world_mat
        if to_world:
            rotation = world_mat[:, :, :3]
            # the inverse of a rotation matrix is its transpose
            points = transform_points(points, rotation.transpose(1, 2))
        if transfer.endswith('_scale_model'):
            translation = world_mat[:, :, 3:]
            points = points * translation[:, 2:, :]

    raw['depth_pointcloud'] = points
    if local:
        points = {None: points}
    return points, raw