def _get_kappa_adv(adv_pc, ori_pc, ori_normal, k=2):
    """Curvature-like score of an adversarial cloud, using normals borrowed
    from the original cloud.

    For every adversarial point the normal of its nearest original point is
    looked up, then the unit directions towards its ``k`` nearest adversarial
    neighbours are projected onto that normal; the absolute projections are
    averaged over the neighbours.

    Args:
        adv_pc: adversarial points, laid out [b, 3, n] according to the
            shape annotations of the original code.
        ori_pc: original points, same layout.
        ori_normal: normals of the original points, same layout.
        k: number of neighbours (the query itself is not counted).

    Returns:
        Tuple ``(kappa, normal)`` annotated as [b, n] and [b, 3, n].
    """
    # Unpacking doubles as a rank check: anything but a 3-D tensor fails here.
    _batch, _, _num_points = adv_pc.size()

    adv_pts = adv_pc.permute(0, 2, 1)

    # Nearest original point of every adversarial point -> its normal n_p.
    nearest_ori = knn_points(adv_pts, ori_pc.permute(0, 2, 1), K=1)
    gathered_normal = knn_gather(ori_normal.permute(0, 2, 1), nearest_ori.idx)
    normal = gathered_normal.permute(0, 3, 1, 2).squeeze(3).contiguous()

    # k+1 neighbours inside the adversarial cloud; slot 0 is dropped
    # (presumably the query point itself, at distance zero).
    self_knn = knn_points(adv_pts, adv_pts, K=k + 1)
    neighbours = knn_gather(adv_pts, self_knn.idx).permute(0, 3, 1, 2)
    neighbours = neighbours[:, :, :, 1:].contiguous()

    # NOTE(review): _normalize is defined elsewhere; presumably it rescales
    # the offset vectors to unit length -- confirm.
    directions = _normalize(neighbours - adv_pc.unsqueeze(3))

    kappa = torch.abs((directions * normal.unsqueeze(3)).sum(1)).mean(2)
    return kappa, normal
def get_normal_w(self, point_clouds: PointClouds3D, normals: Optional[torch.Tensor] = None, **kwargs):
    """
    Normal-similarity weights exp(-|n - ni|^2 / sharpness_sigma^2) for every
    neighbour i listed in self.knn_tree.idx.

    Args:
        point_clouds: supplies the per-cloud lengths and, when ``normals`` is
            not given, the padded normals.
        normals (tensor): (N, maxP, 3) padded normals used as n; defaults to
            the normals stored in ``point_clouds``.
        **kwargs: an optional ``sharpness_sigma`` entry replaces (and is
            stored in) ``self.sharpness_sigma``.
    Returns:
        weight per point per neighbor (N,maxP,K)
    """
    self.sharpness_sigma = kwargs.get('sharpness_sigma', self.sharpness_sigma)
    sigma = self.sharpness_sigma
    inv_sigma_normal = 1 / (sigma * sigma)

    lengths = point_clouds.num_points_per_cloud()
    if normals is None:
        normals = point_clouds.normals_padded()

    # Gather first, normalise afterwards: both sides are made unit length
    # independently before being compared.
    neighbour_normals = ops.knn_gather(normals, self.knn_tree.idx, lengths)
    unit_normals = torch.nn.functional.normalize(normals, dim=-1)
    unit_neighbours = torch.nn.functional.normalize(neighbour_normals, dim=-1)

    diff = unit_neighbours - unit_normals[:, :, None, :]
    squared_dist = torch.sum(diff * diff, dim=-1)
    return torch.exp(-squared_dist * inv_sigma_normal)
def _get_kappa_ori(pc, normal, k=2):
    """Curvature-like score of a cloud with respect to its own normals.

    The unit directions from every point to its ``k`` nearest neighbours are
    projected onto the point's normal; the absolute projections are averaged
    over the neighbours.

    Args:
        pc: points, laid out [b, 3, n] according to the shape annotations of
            the original code.
        normal: normals of ``pc``, same layout.
        k: number of neighbours (the query itself is not counted).

    Returns:
        Tensor annotated as [b, n].
    """
    # Unpacking doubles as a rank check: anything but a 3-D tensor fails here.
    _batch, _, _num_points = pc.size()

    pts = pc.permute(0, 2, 1)

    # k+1 neighbours inside the cloud; slot 0 is dropped (presumably the
    # query point itself, at distance zero).
    self_knn = knn_points(pts, pts, K=k + 1)
    neighbours = knn_gather(pts, self_knn.idx).permute(0, 3, 1, 2)
    neighbours = neighbours[:, :, :, 1:].contiguous()

    # NOTE(review): _normalize is defined elsewhere; presumably it rescales
    # the offset vectors to unit length -- confirm.
    directions = _normalize(neighbours - pc.unsqueeze(3))

    projection = (directions * normal.unsqueeze(3)).sum(1)
    return torch.abs(projection).mean(2)
def _denoise_normals(self, point_clouds, weights, point_clouds_filter=None):
    """
    Robust normal mollification (Sec 4.4): replace every normal with the
    weighted average of the normals of its neighbours in self.knn_tree.

    When a filter is supplied, points flagged as both visible and in-mask
    keep their original normal; only the remaining points are smoothed.

    Args:
        point_clouds: clouds whose normals are smoothed; updated through
            ``update_normals_`` and returned.
        weights (tensor): (N, max_P, K) weight per point per neighbour.
        point_clouds_filter: optional object exposing ``visibility`` and
            ``inmask`` masks.
    Returns:
        ``point_clouds`` after its normals have been updated.
    Raises:
        ValueError: the filter mask and the point clouds have different batch
            sizes and the mask cannot be reduced to a single cloud.
    """
    lengths = point_clouds.num_points_per_cloud()
    P_total = lengths.sum().item()
    normals = point_clouds.normals_padded()

    knn_normals = ops.knn_gather(normals, self.knn_tree.idx, lengths)
    weighted_sum = torch.sum(knn_normals * weights[:, :, :, None], dim=-2)
    normals_denoised = weighted_sum / eps_denom(
        torch.sum(weights, dim=-1, keepdim=True))

    # Restore the original normal wherever the filter marks it as reliable.
    if point_clouds_filter is not None:
        try:
            reliable_normals_mask = (point_clouds_filter.visibility
                                     & point_clouds_filter.inmask)
        except KeyError:
            # Best effort, as before: a filter that does not carry these
            # entries simply leaves every normal smoothed.
            reliable_normals_mask = None

        if reliable_normals_mask is not None:
            n_clouds = len(point_clouds)
            if n_clouds != reliable_normals_mask.shape[0]:
                if n_clouds == 1 and reliable_normals_mask.shape[0] > 1:
                    # Several masks for a single cloud: a point is reliable
                    # if any of them says so.
                    reliable_normals_mask = reliable_normals_mask.any(
                        dim=0, keepdim=True)
                else:
                    # The exception used to be constructed but never raised.
                    raise ValueError(
                        "Incompatible point clouds {} and mask {}".format(
                            n_clouds, reliable_normals_mask.shape))

            # The left-hand side used to index with the undefined name
            # `pts_reliable_normals_mask`, which raised NameError.
            keep_original = reliable_normals_mask == 1
            normals_denoised[keep_original] = normals[keep_original]

    normals_denoised_packed = ops.padded_to_packed(
        normals_denoised, point_clouds.cloud_to_packed_first_idx(), P_total)
    point_clouds.update_normals_(normals_denoised_packed)
    return point_clouds
def knn_group_withI(points1, points2, intensity2, k):
    '''
    For each point in points1, group its k nearest points of points2 together
    with their intensities.

    Input:
        points1: [B,3,N]
        points2: [B,3,N]
        intensity2: [B,1,N]
    Output (M is the number of points in points1):
        new_features: [B,4,M,k]  offset to the neighbour (3) + its length (1)
        nn: [B,3,M,k]            neighbour coordinates
        grouped_features: [B,1,M,k]  neighbour intensities
    '''
    query = points1.permute(0, 2, 1).contiguous()
    reference = points2.permute(0, 2, 1).contiguous()

    _, nn_idx, nn = knn_points(query, reference, K=k, return_nn=True)

    # Broadcasting over the neighbour axis gives the same values as
    # explicitly repeating the query points k times.
    offsets = nn - query.unsqueeze(2)  # [B,M,k,3]
    offset_len = torch.norm(offsets, dim=-1, keepdim=True)
    new_features = torch.cat([offsets, offset_len], dim=-1)

    grouped_features = knn_gather(intensity2.permute(0, 2, 1), nn_idx)  # [B,M,k,1]

    def _channels_first(tensor):
        return tensor.permute(0, 3, 1, 2).contiguous()

    return (_channels_first(new_features),
            _channels_first(nn),
            _channels_first(grouped_features))
def estimate_normal_via_ori_normal(pc_adv, pc_ori, normal_ori, k):
    """Estimate normals of an adversarial cloud from the original cloud.

    Every adversarial point takes the averaged, re-normalised normal of its
    ``k`` nearest original points, unless it coincides with an original point
    (nearest distance below 1e-6), in which case that point's normal is used
    unchanged.

    Args:
        pc_adv: adversarial points, [b, 3, n].
        pc_ori: original points, [b, 3, n].
        normal_ori: normals of the original points, [b, 3, n].
        k: number of original neighbours to average over.

    Returns:
        Estimated normals, [b, 3, n].
    """
    intra_KNN = knn_points(pc_adv.permute(0, 2, 1),
                           pc_ori.permute(0, 2, 1),
                           K=k)  # dists: [b, n, k], idx: [b, n, k]

    # Distance to the closest original point.
    # NOTE(review): PyTorch3D documents knn_points distances as squared, so
    # the 1e-6 threshold below is presumably on squared distance -- confirm.
    nearest_dist = intra_KNN.dists[:, :, 0].contiguous()

    normal_pts = knn_gather(normal_ori.permute(0, 2, 1),
                            intra_KNN.idx).permute(0, 3, 1, 2).contiguous()  # [b, 3, n, k]

    # keepdim=True is required: without it the [b, n] norm is broadcast
    # against [b, 3, n] along the wrong axes (wrong values for b == 3, a
    # shape error for most other batch sizes).
    normal_pts_avg = normal_pts.mean(dim=-1)
    normal_pts_avg = normal_pts_avg / (
        normal_pts_avg.norm(dim=1, keepdim=True) + 1e-12)

    # Unmodified points (distance ~ 0) reuse the original normal directly;
    # all others get the mean of the k nearest normals.
    normal_ori_select = normal_pts[:, :, :, 0]
    condition = (nearest_dist < 1e-6).unsqueeze(1).expand_as(normal_ori_select)
    return torch.where(condition, normal_ori_select, normal_pts_avg)
def knn_group(self, points1, points2, features2, k):
    '''
    For each point in points1, query kNN points/features in points2/features2

    Input:
        points1: [B,3,N]
        points2: [B,3,N]
        features2: [B,C,N]
    Output (M is the number of points in points1):
        new_features: [B,4,M,k]  offset to the neighbour (3) + its length (1)
        nn: [B,3,M,k]            neighbour coordinates
        grouped_features: [B,C,M,k]  neighbour features
    '''
    query = points1.permute(0, 2, 1).contiguous()
    reference = points2.permute(0, 2, 1).contiguous()

    _, nn_idx, nn = knn_points(query, reference, K=k, return_nn=True)

    # Broadcasting over the neighbour axis gives the same values as
    # explicitly repeating the query points k times.
    offsets = nn - query.unsqueeze(2)
    offset_len = torch.norm(offsets, dim=-1, keepdim=True)
    new_features = torch.cat([offsets, offset_len], dim=-1)

    grouped_features = knn_gather(features2.permute(0, 2, 1), nn_idx)

    def _channels_first(tensor):
        return tensor.permute(0, 3, 1, 2).contiguous()

    return (_channels_first(new_features),
            _channels_first(nn),
            _channels_first(grouped_features))
def forward(self, depth_img_dict_1, depth_img_dict_2, cam_info, R12, t12, R21, t21):
    """Kernel correlation between two views' point clouds.

    Each view is loaded twice through ``self.load_pcl``: in its own frame and
    transformed into the other view's frame. For both directions every point
    of one view is matched to its ``self.num_nn`` nearest points of the other
    view (expressed in the first view's frame), and a distance kernel, a
    colour kernel and a normal kernel are multiplied, masked, summed and
    normalised by ``num_points * self.num_nn``.

    Returns:
        The mean of the two directional kernels, or the mean of their logs
        when ``self.log_loss`` is set.
    """
    ### format them into point clouds
    pcl_gt_1, pcl_gt_1_in_2 = self.load_pcl(cam_info, depth_img_dict_1, R21, t21)
    pcl_gt_2, pcl_gt_2_in_1 = self.load_pcl(cam_info, depth_img_dict_2, R12, t12)

    ell = self.ell_min + self.ell_rand
    color_scale = 0.2

    def _directional_kernel(pcl_query, pcl_ref):
        # One direction: query cloud against its neighbours in pcl_ref.
        ref_lengths = pcl_ref.num_points_per_cloud()
        query_features = pcl_query.features_padded()
        ref_features = pcl_ref.features_padded()

        ### knn: (N, P1, D), (N, P2, D) -> (N, P1, K), (N, P1, K), (N, P1, K, D)
        dists, idxs, _ = knn_pcl(pcl_query, pcl_ref, self.num_nn)

        ### distance kernel; the length scale is derived from point channel 2
        # (presumably depth -- confirm), shape (N, P, 1)
        length_scale = ell * (pcl_query.points_padded()[:, :, [2]] - self.ell_basedist) / self.ell_basedist
        length_scale = length_scale.clamp(min=ell).pow(2)
        dist_kernel = torch.exp( - dists / length_scale )

        ### color kernel on feature channels 0..2 (named "hsv" in the
        # original code)
        color_knn = knn_gather(ref_features[:, :, :3], idxs, ref_lengths)  # (N, P1, K, D)
        color_dist = (query_features[:, :, :3].unsqueeze(2) - color_knn).norm(-1)  # (N, P1, K)
        color_kernel = torch.exp( - color_dist / color_scale )

        ### normal kernel, damped by feature channel 3 (named "nres" in the
        # original code)
        normal_knn = knn_gather(pcl_ref.normals_padded(), idxs, ref_lengths)  # (N, P1, K, D)
        nres_query = query_features[:, :, [3]]
        nres_knn = knn_gather(ref_features[:, :, [3]], idxs, ref_lengths)
        alpha = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_query.unsqueeze(2) + nres_knn ).squeeze(-1)
        normal_kernel = ((pcl_query.normals_padded().unsqueeze(2) * normal_knn).sum(-1)*alpha).clamp(min=0)

        ### final kernel
        mask = mask_from_pcl(pcl_query)
        total = (dist_kernel * color_kernel * normal_kernel * mask).sum()
        return total / (pcl_query.num_points_per_cloud().sum()*self.num_nn)

    kernel_1 = _directional_kernel(pcl_gt_1, pcl_gt_2_in_1)
    kernel_2 = _directional_kernel(pcl_gt_2, pcl_gt_1_in_2)

    if self.log_loss:
        return (kernel_1.log() + kernel_2.log()) / 2
    return (kernel_1 + kernel_2) / 2
def forward(self, depth_img_dict_1=None, depth_img_dict_2=None, flow_dict_1to2=None, flow_dict_2to1=None, cam_info=None):
    """Kernel correlation between ground-truth clouds and predicted / flowed
    clouds of two views.

    ``self.load_pcl`` yields, per view, a predicted cloud, a ground-truth
    cloud and the cloud flowed into the other view. Four kernels are
    evaluated, each matching a ground-truth cloud to the ``self.num_nn``
    nearest points of another cloud:

        gt_1 vs pred_1, gt_2 vs pred_2,
        gt_1 vs flowed_1_from_2, gt_2 vs flowed_2_from_1.

    Each kernel is the masked sum of (distance * colour * normal) kernels
    normalised by ``num_points * self.num_nn``.

    The flowed clouds are used unconditionally. The previous code fetched
    their normals and residuals only when both flow dicts were given but used
    them regardless, so a missing flow dict ended in UnboundLocalError. That
    vestigial guard is gone; whether ``flow_dict_* = None`` is usable at all
    depends on ``self.load_pcl`` (not visible here).

    Returns:
        Sum of the four kernels, or of their logs when ``self.log_loss`` is
        set.
    """
    ### format them into point clouds
    pcl_pred_1, pcl_gt_1, pcl_flowed_2_from_1 = self.load_pcl(cam_info, depth_img_dict_1, flow_dict_1to2)
    pcl_pred_2, pcl_gt_2, pcl_flowed_1_from_2 = self.load_pcl(cam_info, depth_img_dict_2, flow_dict_2to1)

    # Read once so every kernel shares the same value.
    ell = self.ell_min + self.ell_rand
    color_scale = 0.2

    def _length_scale(pcl_gt):
        # Length scale from point channel 2 (presumably depth -- confirm);
        # points_padded is (N, P, D), the result is (N, P, 1).
        scale = ell * (pcl_gt.points_padded()[:, :, [2]] - self.ell_basedist) / self.ell_basedist
        return scale.clamp(min=ell).pow(2)

    def _kernel(pcl_gt, pcl_other, length_scale, mask):
        # Kernel correlation of pcl_gt against its neighbours in pcl_other.
        other_lengths = pcl_other.num_points_per_cloud()
        gt_features = pcl_gt.features_padded()
        other_features = pcl_other.features_padded()

        ### knn: (N, P1, D), (N, P2, D) -> (N, P1, K), (N, P1, K), (N, P1, K, D)
        dists, idxs, _ = knn_pcl(pcl_gt, pcl_other, self.num_nn)

        ### distance kernel
        dist_kernel = torch.exp( - dists / length_scale )

        ### color kernel on feature channels 0..2 (named "hsv" in the
        # original code)
        color_knn = knn_gather(other_features[:, :, :3], idxs, other_lengths)  # (N, P1, K, D)
        color_dist = (gt_features[:, :, :3].unsqueeze(2) - color_knn).norm(-1)  # (N, P1, K)
        color_kernel = torch.exp( - color_dist / color_scale )

        ### normal kernel, damped by feature channel 3 (named "nres" in the
        # original code): alpha = 2 * mag_min / (2*mag_min/mag_max + res)
        normal_knn = knn_gather(pcl_other.normals_padded(), idxs, other_lengths)  # (N, P1, K, D)
        nres_gt = gt_features[:, :, [3]]
        nres_knn = knn_gather(other_features[:, :, [3]], idxs, other_lengths)
        alpha = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_gt.unsqueeze(2) + nres_knn ).squeeze(-1)
        normal_kernel = ((pcl_gt.normals_padded().unsqueeze(2) * normal_knn).sum(-1)*alpha).clamp(min=0)

        ### final kernel
        total = (dist_kernel * color_kernel * normal_kernel * mask).sum()
        return total / (pcl_gt.num_points_per_cloud().sum()*self.num_nn)

    length_scale_1 = _length_scale(pcl_gt_1)
    length_scale_2 = _length_scale(pcl_gt_2)
    mask_1 = mask_from_pcl(pcl_gt_1)
    mask_2 = mask_from_pcl(pcl_gt_2)

    kernel_1 = _kernel(pcl_gt_1, pcl_pred_1, length_scale_1, mask_1)
    kernel_2 = _kernel(pcl_gt_2, pcl_pred_2, length_scale_2, mask_2)
    kernel_flowed_1 = _kernel(pcl_gt_1, pcl_flowed_1_from_2, length_scale_1, mask_1)
    kernel_flowed_2 = _kernel(pcl_gt_2, pcl_flowed_2_from_1, length_scale_2, mask_2)

    if self.log_loss:
        inp = kernel_1.log() + kernel_2.log() + kernel_flowed_1.log() + kernel_flowed_2.log()
    else:
        inp = kernel_1 + kernel_2 + kernel_flowed_1 + kernel_flowed_2
    return inp
def compute(self, point_clouds: PointClouds3D, points_filters=None, rebuild_knn=True, **kwargs):
    """
    Repulsion term: project every point onto its local surface, then reward
    (negative sign) large weighted squared distances between the projected
    point and its neighbours.

    Args:
        point_clouds: clouds to evaluate; their normals are denoised twice
            through ``self._denoise_normals``.
        points_filters: forwarded to ``self._denoise_normals``.
        rebuild_knn: rebuild ``self.knn_tree`` even if one is cached.
        (optional) knn_tree: replaces ``self.knn_tree``; must expose ``idx``,
            ``dists`` and ``knn``.
        (optional) knn_mask: replaces ``self.knn_mask`` (valid knn results).
    Returns:
        Negative weighted mean of squared point-to-neighbour distances, one
        value per packed point (the reduction runs over dim 1 of the packed
        tensors).
    """
    self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
    self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

    lengths = point_clouds.num_points_per_cloud()
    P_total = lengths.sum().item()
    points_padded = point_clouds.points_padded()

    # Compute necessary weights to project points to local plane
    # TODO(yifan): This part is same as ProjectionLoss
    # how can we at best save repetitive computation
    # NOTE(review): the extent of both no_grad blocks was reconstructed from
    # the .detach() calls and in-place mask writes -- confirm.
    with torch.autograd.no_grad():
        # The cached tree is compared through its idx tensor: the tree is
        # accessed via .idx/.dists/.knn everywhere else, and the previous
        # `self.knn_tree.shape` lookup was inconsistent with that usage.
        if (rebuild_knn or self.knn_tree is None
                or points_padded.shape[:2] != self.knn_tree.idx.shape[:2]):
            self._build_knn(point_clouds)

        phi = self.get_phi(point_clouds, **kwargs)

        self._denoise_normals(point_clouds, phi, points_filters)

        # compute wn and wr
        # TODO(yifan): visibility weight?
        normal_w = self.get_normal_w(point_clouds, **kwargs)

        # update normals for a second iteration (?) Eq.(10)
        point_clouds = self._denoise_normals(point_clouds, phi * normal_w,
                                             points_filters)

        # compose weights
        weights = phi * normal_w
        weights[~self.knn_mask] = 0

        # outside filter_scale*local_point_spacing weights
        mask_ball_query = self.knn_tree.dists > (
            self.filter_scale * self.knn_tree.dists[:, :, :1] * 2.0)
        weights[mask_ball_query] = 0.0

    # project the point to a local surface
    knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                 self.knn_tree.idx, lengths)
    dist_to_surface = torch.sum(
        (self.knn_tree.knn.detach() - points_padded.unsqueeze(-2)) * knn_normals,
        dim=-1)
    deltap = torch.sum(
        dist_to_surface[..., None] * weights[..., None] * knn_normals,
        dim=-2) / eps_denom(torch.sum(weights, dim=-1, keepdim=True))
    points_projected = points_padded + deltap

    if get_debugging_mode():
        def save_grad():
            lengths = point_clouds.num_points_per_cloud()

            def _save_grad(grad):
                dbg_tensor = get_debugging_tensor()
                # Log and skip instead of failing inside backward().
                if dbg_tensor is None:
                    logger_py.error("dbg_tensor is None")
                    return
                if grad is None:
                    logger_py.error('grad is None')
                    return
                # a dict of list of tensors
                dbg_tensor.pts_world_grad['repel'] = [
                    grad[b, :lengths[b]].detach().cpu()
                    for b in range(grad.shape[0])
                ]

            return _save_grad

        # register_hook raises on tensors that do not require grad; the
        # sibling projection loss guards the same way.
        if points_padded.requires_grad:
            dbg_tensor = get_debugging_tensor()
            dbg_tensor.pts_world['repel'] = [
                points_padded[b, :lengths[b]].detach().cpu()
                for b in range(points_padded.shape[0])
            ]
            handle = points_padded.register_hook(save_grad())
            self.hooks.append(handle)

    with torch.autograd.no_grad():
        spatial_w = self.get_spatial_w(point_clouds, points_projected)
        # density weight is actually spatial_w + 1
        density_w = torch.sum(spatial_w, dim=-1, keepdim=True) + 1.0
        weights = normal_w * spatial_w * density_w
        weights[~self.knn_mask] = 0
        weights[mask_ball_query] = 0

    deltap = points_projected[:, :, None, :] - self.knn_tree.knn.detach()
    point_to_point_dist = torch.sum(deltap * deltap, dim=-1)

    # convert everything to packed
    weights = ops.padded_to_packed(
        weights, point_clouds.cloud_to_packed_first_idx(), P_total)
    point_to_point_dist = ops.padded_to_packed(
        point_to_point_dist, point_clouds.cloud_to_packed_first_idx(), P_total)

    # we want to maximize this, so negative sign
    point_to_point_dist = -torch.sum(point_to_point_dist * weights, dim=1) / eps_denom(
        torch.sum(weights, dim=1))
    return point_to_point_dist
def compute(self, point_clouds: PointClouds3D, points_filters=None, rebuild_knn=False, **kwargs):
    """
    Projection term: squared weighted signed distance of every point to the
    local surface spanned by its neighbours and their (denoised) normals.

    Steps:
        - determine phi spatial using local point spacing (i.e. 2*dist_to_nn)
        - denoise normals
        - determine w_normal
        - mask out values outside the ball neighbourhood, i.e.
          d > filterSpatialScale * localPointSpacing
        - projected distance dot(ni, x-xi)
        - multiply and normalize the weights

    Args:
        point_clouds: clouds to evaluate; their normals are denoised twice
            through ``self._denoise_normals``.
        points_filters: forwarded to ``self._denoise_normals``.
        rebuild_knn: rebuild ``self.knn_tree`` even if one is cached.
        (optional) knn_tree: output from ops.knn_points excluding the query
            point itself
        (optional) knn_mask: mask valid knn results
        (optional) sharpness_sigma, filter_scale: replace the attributes of
            the same name.
    Returns:
        Squared weighted mean distance to surface, one value per packed point
        (the reduction runs over the last axis of the packed tensors).
    """
    self.sharpness_sigma = kwargs.get('sharpness_sigma', self.sharpness_sigma)
    self.filter_scale = kwargs.get('filter_scale', self.filter_scale)
    self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
    self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

    lengths = point_clouds.num_points_per_cloud()
    P_total = lengths.sum().item()
    points = point_clouds.points_padded()

    # NOTE(review): the extent of the no_grad block was reconstructed from
    # the .detach() call and in-place mask writes -- confirm.
    with torch.autograd.no_grad():
        if (rebuild_knn or self.knn_tree is None
                or self.knn_tree.idx.shape[:2] != points.shape[:2]):
            self._build_knn(point_clouds)

        phi = self.get_phi(point_clouds, **kwargs)

        # robust normal mollification (Sec 4.4), i.e. replace normals with a
        # weighted average from neighboring normals Eq.(11)
        point_clouds = self._denoise_normals(point_clouds, phi, points_filters)

        # compute wn and wr
        # TODO(yifan): visibility weight?
        normal_w = self.get_normal_w(point_clouds, **kwargs)
        spatial_w = self.get_spatial_w(point_clouds, **kwargs)

        # update normals for a second iteration (?) Eq.(10)
        point_clouds = self._denoise_normals(point_clouds, phi * normal_w,
                                             points_filters)

        # compose weights
        weights = phi * spatial_w * normal_w
        weights[~self.knn_mask] = 0

        # outside filter_scale*local_point_spacing weights
        mask_ball_query = self.knn_tree.dists > (
            self.filter_scale * self.knn_tree.dists[:, :, :1] * 2.0)
        weights[mask_ball_query] = 0.0

    # (B, P, k), dot product distance to surface
    # (we need to gather again because the normals have been changed in the
    # denoising step)
    knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                 self.knn_tree.idx, lengths)
    dist_to_surface = torch.sum(
        (self.knn_tree.knn.detach() - points.unsqueeze(-2)) * knn_normals,
        dim=-1)

    if get_debugging_mode():
        def save_grad():
            lengths = point_clouds.num_points_per_cloud()

            def _save_grad(grad):
                dbg_tensor = get_debugging_tensor()
                # Log and skip instead of failing inside backward().
                if dbg_tensor is None:
                    logger_py.error("dbg_tensor is None")
                    return
                if grad is None:
                    logger_py.error('grad is None')
                    return
                # a dict of list of tensors
                dbg_tensor.pts_world_grad['proj'] = [
                    grad[b, :lengths[b]].detach().cpu()
                    for b in range(grad.shape[0])
                ]

            return _save_grad

        # register_hook raises on tensors that do not require grad; the
        # sibling projection loss guards the same way.
        if points.requires_grad:
            dbg_tensor = get_debugging_tensor()
            dbg_tensor.pts_world['proj'] = [
                points[b, :lengths[b]].detach().cpu()
                for b in range(points.shape[0])
            ]
            handle = points.register_hook(save_grad())
            self.hooks.append(handle)

    # convert everything to packed
    weights = ops.padded_to_packed(
        weights, point_clouds.cloud_to_packed_first_idx(), P_total)
    dist_to_surface = ops.padded_to_packed(
        dist_to_surface, point_clouds.cloud_to_packed_first_idx(), P_total)

    # compute weighted signed distance to surface
    dist_to_surface = torch.sum(weights * dist_to_surface, dim=-1) / eps_denom(
        torch.sum(weights, dim=-1))
    loss = dist_to_surface * dist_to_surface
    return loss
def estimate_normal(pc, k):
    """Estimate per-point normals from local neighbourhood covariance.

    For every point the covariance of its ``k`` nearest neighbours is built
    and the eigenvector of the smallest eigenvalue is taken as the normal.
    Runs entirely under ``torch.no_grad()``.

    Changes with respect to the previous implementation:
      * each batch element is appended exactly once (a duplicated append used
        to add a second, differently shaped tensor and break ``torch.stack``);
      * the batched torch path is always used. The old switch looked only at
        the *minor* version number, so e.g. torch 1.0-1.3 and 2.0-2.3 fell
        into a per-point numpy fallback that could not run (``torch.bmm`` on
        2-D tensors);
      * ``torch.linalg.eigh`` is preferred and ``torch.symeig`` is only used
        where ``torch.linalg.eigh`` does not exist.

    Args:
        pc: points, [b, 3, n].
        k: neighbourhood size (the query point itself is excluded). ``k == 1``
            still raises ZeroDivisionError through ``1 / (k - 1)``.

    Returns:
        Float tensor of normals, [b, 3, n].
    """
    with torch.no_grad():
        b, _, n = pc.size()

        # get knn point set matrix
        pts = pc.permute(0, 2, 1)
        inter_KNN = knn_points(pts, pts, K=k + 1)  # dists/idx: [b, n, k+1]
        nn_pts = knn_gather(pts, inter_KNN.idx).permute(
            0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n, k]

        # Symmetric eigen-solver: eigenvalues plus column eigenvectors.
        if hasattr(torch, 'linalg') and hasattr(torch.linalg, 'eigh'):
            sym_eig = torch.linalg.eigh
        else:
            def sym_eig(mat):
                return torch.symeig(mat, eigenvectors=True)

        fact = 1.0 / (k - 1)

        # get covariance matrix and smallest eig-vector of individual point
        normal_vector = []
        for i in range(b):
            curr_point_set = nn_pts[i].detach().permute(1, 0, 2)  # [n, 3, k]
            curr_point_set_mean = torch.mean(
                curr_point_set, dim=2, keepdim=True)  # [n, 3, 1]
            curr_point_set = curr_point_set - curr_point_set_mean  # [n, 3, k]
            curr_point_set_t = curr_point_set.permute(0, 2, 1)  # [n, k, 3]
            cov_mat = fact * torch.bmm(curr_point_set,
                                       curr_point_set_t)  # [n, 3, 3]

            eigenvalue, eigenvector = sym_eig(cov_mat)  # [n, 3], [n, 3, 3]
            smallest = torch.argmin(eigenvalue, dim=1).view(n, 1, 1).expand(
                n, 3, 1)
            # squeeze(2) rather than squeeze(): keeps [n, 3] even when n == 1
            persample_normal_vector = torch.gather(
                eigenvector, 2, smallest).squeeze(2)  # [n, 3]

            # recorrect the direction via neighbour direction
            # NOTE(review): curr_point_set was mean-centred above, so this sum
            # is analytically zero and the sign below is driven by rounding
            # noise (and zeroes the normal when it is exactly 0). The intent
            # was presumably to sum offsets relative to the query point --
            # left unchanged pending confirmation.
            nbr_sum = curr_point_set.sum(dim=2)  # [n, 3]
            sign = -torch.sign(
                torch.bmm(persample_normal_vector.view(n, 1, 3),
                          nbr_sum.view(n, 3, 1))).squeeze(2)  # [n, 1]
            persample_normal_vector = sign * persample_normal_vector

            normal_vector.append(persample_normal_vector.permute(1, 0))

        normal_vector = torch.stack(normal_vector, 0)  # [b, 3, n]
    return normal_vector.float()
def compute(self, point_clouds: PointClouds3D, points_filter=None, rebuild_knn=False, **kwargs):
    """
    Projection term with visibility weighting: weighted mean of the squared
    signed distances from every point to the planes through its neighbours.

    Steps:
        - determine phi spatial using local point spacing (i.e. 2*dist_to_nn)
        - denoise normals
        - determine w_normal
        - down-weight neighbours that are not visible (factor 0.1)
        - projected distance dot(ni, x-xi)
        - multiply and normalize the weights

    Args:
        point_clouds: clouds to evaluate.
        points_filter: optional object exposing a boolean ``visibility``
            mask; also forwarded to ``self._denoise_normals``. When it is
            ``None`` (the default) the visibility weighting is skipped;
            previously the default crashed on ``None.visibility``.
        rebuild_knn: rebuild ``self.knn_tree`` even if one is cached.
        (optional) knn_tree: output from ops.knn_points excluding the query
            point itself
        (optional) knn_mask: mask valid knn results
        (optional) sharpness_sigma, filter_scale: replace the attributes of
            the same name.
    Returns:
        Weighted mean squared distance to surface, one value per packed point
        (the reduction runs over the last axis of the packed tensors).
    """
    self.sharpness_sigma = kwargs.get('sharpness_sigma', self.sharpness_sigma)
    self.filter_scale = kwargs.get('filter_scale', self.filter_scale)
    self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
    self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

    lengths = point_clouds.num_points_per_cloud()
    P_total = lengths.sum().item()
    points = point_clouds.points_padded()

    # NOTE(review): the extent of the no_grad block was reconstructed from
    # the .detach() call and the in-place write on visibility_w -- confirm.
    with torch.autograd.no_grad():
        if (rebuild_knn or self.knn_tree is None
                or self.knn_tree.idx.shape[:2] != points.shape[:2]):
            self._build_knn(point_clouds)

        phi = self.get_phi(point_clouds, **kwargs)

        # robust normal mollification (Sec 4.4), i.e. replace normals with a
        # weighted average from neighboring normals Eq.(11)
        point_clouds = self._denoise_normals(point_clouds, phi, points_filter,
                                             inplace=False)

        # compute wn and wr
        normal_w = self.get_normal_w(point_clouds, **kwargs)

        # compose weights
        weights = phi * normal_w

        # visibility weight: visible neighbours count fully, the others 0.1
        if points_filter is not None:
            visibility_nb = ops.knn_gather(
                points_filter.visibility.unsqueeze(-1), self.knn_tree.idx,
                lengths)
            visibility_w = visibility_nb.float()
            visibility_w[~visibility_nb] = 0.1
            weights = weights * visibility_w.squeeze(-1)

    # (B, P, k), dot product distance to surface
    knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                 self.knn_tree.idx, lengths)

    if get_debugging_mode():
        def save_grad():
            lengths = point_clouds.num_points_per_cloud()

            def _save_grad(grad):
                dbg_tensor = get_debugging_tensor()
                # Log and skip instead of failing inside backward().
                if dbg_tensor is None:
                    logger_py.error("dbg_tensor is None")
                    return
                if grad is None:
                    logger_py.error('grad is None')
                    return
                # a dict of list of tensors
                dbg_tensor.pts_world_grad['proj'] = [
                    grad[b, :lengths[b]].detach().cpu()
                    for b in range(grad.shape[0])
                ]

            return _save_grad

        if points.requires_grad:
            dbg_tensor = get_debugging_tensor()
            dbg_tensor.pts_world['proj'] = [
                points[b, :lengths[b]].detach().cpu()
                for b in range(points.shape[0])
            ]
            handle = points.register_hook(save_grad())
            self.hooks.append(handle)

    sdf = torch.sum(
        (self.knn_tree.knn.detach() - points.unsqueeze(-2)) * knn_normals,
        dim=-1)

    # convert everything to packed
    weights = ops.padded_to_packed(
        weights, point_clouds.cloud_to_packed_first_idx(), P_total)
    sdf = ops.padded_to_packed(
        sdf, point_clouds.cloud_to_packed_first_idx(), P_total)

    distance_to_face = sdf * sdf

    # compute weighted signed distance to surface
    loss = torch.sum(weights * distance_to_face, dim=-1) / eps_denom(
        torch.sum(weights, dim=-1))
    return loss
def run(pointcloud_path, out_dir, decoder_type='siren', resume=True, **kwargs):
    """
    test_implicit_siren_noisy_wNormals

    Fit an implicit SDF decoder to an oriented point cloud read from a ply
    file, optionally using iso-points extracted from the current network as
    additional supervision after a warm-up phase.

    Args:
        pointcloud_path: ply file with xyz in the first three columns and
            normals in the remaining ones
        out_dir: directory for checkpoints, meshes and debugging output
            (created if missing)
        decoder_type: 'siren' or 'sdf'
        resume: load the last ``*.pt`` checkpoint (sorted by name) found in
            ``out_dir`` and the most recent saved iso-points, if any
        **kwargs: the following keys are read and must be present:
            use_off_normal_loss, warm_up, lambda_surface_sdf,
            lambda_surface_normal, resample_every, ear, outlier_tolerance,
            denoise_normal, weight_mode, lambda_iso_sdf, lambda_iso_normal,
            lambda_eikonal, use_sal_loss, lambda_inter_sal, lambda_inter_sdf
    Raises:
        ValueError: for an unknown ``decoder_type``
    """
    device = torch.device('cuda:0')

    os.makedirs(out_dir, exist_ok=True)

    # data
    points, normals = np.split(read_ply(pointcloud_path).astype('float32'),
                               (3, ),
                               axis=1)

    # normalize: center the bounding box and scale its longest side to 1.5
    pmax, pmin = points.max(axis=0), points.min(axis=0)
    scale = (pmax - pmin).max()
    pcenter = (pmax + pmin) / 2
    points = (points - pcenter) / scale * 1.5
    scale_mat = np.identity(4)
    scale_mat[[0, 1, 2], [0, 1, 2]] = 1 / scale * 1.5
    scale_mat[[0, 1, 2], [3, 3, 3]] = -pcenter / scale * 1.5
    # maps normalized coordinates back to the input frame
    scale_mat_inv = np.linalg.inv(scale_mat)
    normals = normals @ np.linalg.inv(scale_mat[:3, :3].T)
    object_bounding_sphere = np.linalg.norm(points, axis=1).max()
    pcl = trimesh.Trimesh(vertices=points,
                          vertex_normals=normals,
                          process=False)
    pcl.export(os.path.join(out_dir, "input_pcl.ply"), vertex_normal=True)

    assert (np.abs(points).max() < 1)

    dataset = torch.utils.data.TensorDataset(torch.from_numpy(points),
                                             torch.from_numpy(normals))
    batch_size = 5000
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=1,
        shuffle=True,
        collate_fn=tolerating_collate,
    )
    gt_surface_pts_all = torch.from_numpy(points).unsqueeze(0).float()
    gt_surface_normals_all = torch.from_numpy(normals).unsqueeze(0).float()
    gt_surface_normals_all = F.normalize(gt_surface_normals_all, dim=-1)

    if kwargs['use_off_normal_loss']:
        # subsample from pointset
        sub_idx = torch.randperm(gt_surface_normals_all.shape[1])[:20000]
        gt_surface_pts_sub = torch.index_select(gt_surface_pts_all, 1,
                                                sub_idx).to(device=device)
        gt_surface_normals_sub = torch.index_select(
            gt_surface_normals_all, 1, sub_idx).to(device=device)
        gt_surface_normals_sub = denoise_normals(gt_surface_pts_sub,
                                                 gt_surface_normals_sub,
                                                 neighborhood_size=30)

    if decoder_type == 'siren':
        decoder_params = {
            'dim': 3,
            "out_dims": {
                'sdf': 1
            },
            "c_dim": 0,
            "hidden_size": 256,
            'n_layers': 3,
            "first_omega_0": 30,
            "hidden_omega_0": 30,
            "outermost_linear": True,
        }
        decoder = Siren(**decoder_params)
    elif decoder_type == 'sdf':
        decoder_params = {
            'dim': 3,
            "out_dims": {
                'sdf': 1
            },
            "c_dim": 0,
            "hidden_size": 512,
            'n_layers': 8,
            'bias': 1.0,
        }
        decoder = SDF(**decoder_params)
    else:
        raise ValueError(
            "Unknown decoder_type '{}'".format(decoder_type))
    print(decoder)
    decoder = decoder.to(device)

    # training
    total_iter = 30000
    optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     [10000, 20000],
                                                     gamma=0.5)

    shape = Shape(gt_surface_pts_all.cuda(),
                  n_points=gt_surface_pts_all.shape[1] // 16,
                  normals=gt_surface_normals_all.cuda())
    # initialize siren with sphere_initialization
    checkpoint_io = CheckpointIO(out_dir, model=decoder, optimizer=optimizer)
    load_dict = dict()
    if resume:
        models_avail = [f for f in os.listdir(out_dir) if f.endswith('.pt')]
        if models_avail:
            models_avail.sort()
            load_dict = checkpoint_io.load(models_avail[-1])

    it = load_dict.get('it', 0)
    if it > 0:
        # best effort: restore the latest iso-points saved at or before `it`
        try:
            iso_point_files = [
                f for f in os.listdir(out_dir) if f.endswith('iso.ply')
            ]
            iso_point_iters = np.array([
                int(os.path.basename(f[:-len('_iso.ply')]))
                for f in iso_point_files
            ])
            not_newer = (iso_point_iters - it) <= 0
            idx = np.argmax(iso_point_iters[not_newer])
            iso_point_file = np.array(iso_point_files)[not_newer][idx]
            iso_points = torch.from_numpy(
                read_ply(os.path.join(out_dir, iso_point_file))[..., :3])
            shape.points = iso_points.to(device=shape.points.device).view(
                1, -1, 3)
            print('Loaded iso-points from %s' % iso_point_file)
        except Exception as e:
            # keep going with the initial shape points, but say why
            print('Could not load saved iso-points: {}'.format(e))

    # loss
    eikonal_loss = NormalLengthLoss(reduction='mean')

    # start training
    save_ply(os.path.join(out_dir, 'in_iso_points.ply'),
             shape.points.cpu().view(-1, 3))
    # autograd.set_detect_anomaly(True)
    iso_points = shape.points
    iso_points_normal = None
    while True:
        if (it > total_iter):
            checkpoint_io.save('model_{:04d}.pt'.format(it), it=it)
            mesh = get_surface_high_res_mesh(
                lambda x: decoder(x).sdf.squeeze(), resolution=512)
            mesh.apply_transform(scale_mat_inv)
            mesh.export(os.path.join(out_dir, "final.ply"))
            break

        for batch in data_loader:
            gt_surface_pts, gt_surface_normals = batch
            gt_surface_pts.unsqueeze_(0)
            gt_surface_normals.unsqueeze_(0)
            gt_surface_pts = gt_surface_pts.to(device=device).detach()
            gt_surface_normals = gt_surface_normals.to(
                device=device).detach()

            optimizer.zero_grad()
            decoder.train()
            loss = defaultdict(float)

            warmed_up = kwargs['warm_up'] >= 0 and it >= kwargs['warm_up']
            do_resample = it == kwargs['warm_up'] or \
                kwargs['resample_every'] > 0 and \
                (it - kwargs['warm_up']) % kwargs['resample_every'] == 0

            # fixed weights during warm-up, user-provided ones afterwards
            lambda_surface_sdf = 1e3
            lambda_surface_normal = 1e2
            if warmed_up:
                lambda_surface_sdf = kwargs['lambda_surface_sdf']
                lambda_surface_normal = kwargs['lambda_surface_normal']

            # debug
            if (it - kwargs['warm_up']) % 1000 == 0:
                # generate iso surface
                with torch.autograd.no_grad():
                    box_size = (object_bounding_sphere * 2 + 0.2, ) * 3
                    plot_cuts(lambda x: decoder(x).sdf.squeeze().detach(),
                              box_size=box_size,
                              max_n_eval_pts=10000,
                              thres=0.0,
                              imgs_per_cut=1,
                              save_path=os.path.join(out_dir,
                                                     '%010d_iso.html' % it))
                    mesh = get_surface_high_res_mesh(
                        lambda x: decoder(x).sdf.squeeze(), resolution=200)
                    mesh.apply_transform(scale_mat_inv)
                    mesh.export(os.path.join(out_dir, '%010d_mesh.ply' % it))

            if it % 2000 == 0:
                checkpoint_io.save('model.pt', it=it)

            pred_surface_grad = gradient(gt_surface_pts.clone(),
                                         lambda x: decoder(x).sdf)

            # every once in a while update shape and points
            # sample points in space and on the shape
            # use iso points to weigh data points loss
            weights = 1.0
            if warmed_up:
                if do_resample:
                    # re-extract iso-points starting from jittered old ones
                    iso_points = shape.get_iso_points(
                        iso_points + 0.1 *
                        (torch.rand_like(iso_points) - 0.5),
                        decoder,
                        ear=kwargs['ear'],
                        outlier_tolerance=kwargs['outlier_tolerance'])
                    iso_points_normal = estimate_pointcloud_normals(
                        iso_points.view(1, -1, 3), 8, False)
                    if kwargs['denoise_normal']:
                        iso_points_normal = denoise_normals(
                            iso_points, iso_points_normal, num_points=None)
                        iso_points_normal = iso_points_normal.view_as(
                            iso_points)
                elif iso_points_normal is None:
                    iso_points_normal = estimate_pointcloud_normals(
                        iso_points.view(1, -1, 3), 8, False)

                # TODO: use gradient from network or neighborhood?
                iso_points_g = gradient(iso_points.clone(),
                                        lambda x: decoder(x).sdf)
                if do_resample:
                    save_ply(os.path.join(out_dir, '%010d_iso.ply' % it),
                             iso_points.cpu().detach().view(-1, 3),
                             normals=iso_points_g.view(-1, 3).detach().cpu())

                if kwargs['weight_mode'] == 1:
                    weights = get_iso_bilateral_weights(
                        gt_surface_pts, gt_surface_normals, iso_points,
                        iso_points_g).detach()
                elif kwargs['weight_mode'] == 2:
                    weights = get_laplacian_weights(gt_surface_pts,
                                                    gt_surface_normals,
                                                    iso_points,
                                                    iso_points_g).detach()
                elif kwargs['weight_mode'] == 3:
                    weights = get_heat_kernel_weights(
                        gt_surface_pts, gt_surface_normals, iso_points,
                        iso_points_g).detach()

                # weights stays the plain float 1.0 for any other mode, in
                # which case there are no per-point statistics to report
                if (it - kwargs['warm_up']) % 1000 == 0 and \
                        kwargs['weight_mode'] != -1 and \
                        torch.is_tensor(weights):
                    print("min {:.4g}, max {:.4g}, std {:.4g}, mean {:.4g}".
                          format(weights.min(), weights.max(), weights.std(),
                                 weights.mean()))
                    colors = scaler_to_color(1 -
                                             weights.view(-1).cpu().numpy(),
                                             cmap='Reds')
                    save_ply(
                        os.path.join(out_dir, '%010d_batch_weight.ply' % it),
                        (to_homogen(gt_surface_pts).cpu().detach().numpy()
                         @ scale_mat_inv.T)[..., :3].reshape(-1, 3),
                        colors=colors)

                sample_idx = torch.randperm(
                    iso_points.shape[1])[:min(batch_size,
                                              iso_points.shape[1])]
                iso_points_sampled = iso_points.detach()[:, sample_idx, :]
                iso_points_sdf = decoder(iso_points_sampled.detach()).sdf
                loss_iso_points_sdf = iso_points_sdf.abs().mean(
                ) * kwargs['lambda_iso_sdf'] * iso_points_sdf.nelement() / (
                    iso_points_sdf.nelement() + 8000)
                loss['loss_sdf_iso'] = loss_iso_points_sdf.detach()
                loss['loss'] += loss_iso_points_sdf

                # TODO: predict iso_normals from local_frame
                iso_normals_sampled = iso_points_normal.detach()[:,
                                                                 sample_idx, :]
                iso_g_sampled = iso_points_g[:, sample_idx, :]
                loss_normals = torch.mean(
                    (1 - F.cosine_similarity(
                        iso_normals_sampled, iso_g_sampled, dim=-1).abs())
                ) * kwargs['lambda_iso_normal'] * iso_points_sdf.nelement(
                ) / (iso_points_sdf.nelement() + 8000)
                loss['loss_normal_iso'] = loss_normals.detach()
                loss['loss'] += loss_normals

            # off-surface samples: half uniform in [-1, 1]^3, half gaussian
            # perturbations of a random half of the batch
            idx = torch.randperm(gt_surface_pts.shape[1]).to(
                device=gt_surface_pts.device)[:(gt_surface_pts.shape[1] //
                                                2)]
            tmp = torch.index_select(gt_surface_pts, 1, idx)
            space_pts = torch.cat([
                torch.rand_like(tmp) * 2 - 1,
                torch.randn_like(tmp, device=tmp.device, dtype=tmp.dtype) *
                0.1 + tmp
            ],
                                  dim=1)

            space_pts.requires_grad_(True)
            pred_space_sdf = decoder(space_pts).sdf
            pred_space_grad = torch.autograd.grad(
                pred_space_sdf, [space_pts],
                [torch.ones_like(pred_space_sdf)],
                create_graph=True)[0]

            # 1. eikonal term
            loss_eikonal = (
                eikonal_loss(pred_surface_grad) +
                eikonal_loss(pred_space_grad)) * kwargs['lambda_eikonal']
            loss['loss_eikonal'] = loss_eikonal.detach()
            loss['loss'] += loss_eikonal

            # 2. SDF loss
            # loss on iso points
            pred_surface_sdf = decoder(gt_surface_pts).sdf
            loss_sdf = torch.mean(
                weights * pred_surface_sdf.abs()) * lambda_surface_sdf
            if warmed_up and kwargs['lambda_iso_sdf'] != 0:
                # rebalance against the iso-point sdf term
                loss_sdf = loss_sdf * pred_surface_sdf.nelement() / (
                    pred_surface_sdf.nelement() + iso_points_sdf.nelement())

            if kwargs['use_sal_loss'] and iso_points is not None:
                dists, idxs, _ = knn_points(space_pts.view(1, -1, 3),
                                            iso_points.view(1, -1,
                                                            3).detach(),
                                            K=1)
                dists = dists.view_as(pred_space_sdf)
                idxs = idxs.view_as(pred_space_sdf)
                loss_inter = ((eps_sqrt(dists).sqrt() -
                               pred_space_sdf.abs())**
                              2).mean() * kwargs['lambda_inter_sal']
            else:
                alpha = (it / total_iter + 1) * 100
                loss_inter = torch.exp(
                    -alpha *
                    pred_space_sdf.abs()).mean() * kwargs['lambda_inter_sdf']

            loss_sald = torch.tensor(0.0).cuda()
            # prevent wrong closing for open mesh
            if kwargs['use_off_normal_loss'] and it < 1000:
                dists, idxs, _ = knn_points(space_pts.view(1, -1, 3),
                                            gt_surface_pts_sub.view(
                                                1, -1, 3).cuda(),
                                            K=1)
                knn_normal = knn_gather(
                    gt_surface_normals_sub.cuda().view(1, -1, 3),
                    idxs).view(1, -1, 3)
                direction_correctness = -F.cosine_similarity(
                    knn_normal, pred_space_grad, dim=-1)
                direction_correctness[direction_correctness < 0] = 0
                # NOTE(review): dists presumably still has a trailing K=1
                # axis here while direction_correctness does not, so this
                # product may broadcast to a pairwise matrix -- confirm
                # whether a per-point weighting was intended.
                loss_sald = torch.mean(
                    direction_correctness * torch.exp(-2 * dists)) * 2

            # 3. normal direction
            loss_normals = torch.mean(weights * (1 - F.cosine_similarity(
                gt_surface_normals, pred_surface_grad, dim=-1))
                                      ) * lambda_surface_normal
            if warmed_up and kwargs['lambda_iso_normal'] != 0:
                # rebalance against the iso-point normal term
                loss_normals = loss_normals * gt_surface_normals.nelement(
                ) / (gt_surface_normals.nelement() +
                     iso_normals_sampled.nelement())

            loss['loss_sdf'] = loss_sdf.detach()
            loss['loss_inter'] = loss_inter.detach()
            loss['loss_normals'] = loss_normals.detach()
            loss['loss_sald'] = loss_sald
            loss['loss'] += loss_sdf
            loss['loss'] += loss_inter
            loss['loss'] += loss_sald
            loss['loss'] += loss_normals
            loss['loss'].backward()

            torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.)

            optimizer.step()
            scheduler.step()
            if it % 20 == 0:
                print("iter {:05d} {}".format(
                    it, ', '.join([
                        '{}: {}'.format(k, v.item()) for k, v in loss.items()
                    ])))

            it += 1
def compute(self,
            point_clouds: PointClouds3D,
            points_filter=None,
            rebuild_knn=True,
            **kwargs):
    """
    Repulsion loss: every point is pushed away from its neighbors along the
    local surface, i.e. using neighbor offsets with the component along the
    neighbor normal removed.

    Args:
        point_clouds: point cloud providing the padded points and normals
        points_filter: optional, only forwarded to ``_denoise_normals``
        rebuild_knn (bool): rebuild the cached neighborhood even if one
            with matching leading dimensions exists
        (optional) knn_tree: output from ops.knn_points excluding the query
            point itself
        (optional) knn_mask: mask valid knn results
    Returns:
        exp(-|repel_vec|) per point in packed representation (presumably
        shape (P_total, 3) -- TODO confirm against ops.padded_to_packed)
    """
    self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
    self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

    lengths = point_clouds.num_points_per_cloud()
    P_total = lengths.sum().item()
    points_padded = point_clouds.points_padded()

    if not points_padded.requires_grad:
        logger_py.warning(
            'Computing repulsion loss, but points_padded is not differentiable.'
        )

    # Compute necessary weights to project points to local plane
    # TODO(yifan): This part is same as ProjectionLoss
    # how can we at best save repetitive computation
    with torch.autograd.no_grad():
        # the knn result has no shape of its own; compare against the
        # index tensor, as the projection loss does
        if rebuild_knn or self.knn_tree is None or \
                points_padded.shape[:2] != self.knn_tree.idx.shape[:2]:
            self._build_knn(point_clouds)

        phi = self.get_phi(point_clouds, **kwargs)

        point_clouds = self._denoise_normals(point_clouds,
                                             phi,
                                             points_filter,
                                             inplace=False)

    # project the point to a local surface
    knn_diff = points_padded.unsqueeze(-2) - self.knn_tree.knn.detach()
    knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                 self.knn_tree.idx, lengths)
    pts_diff_proj = knn_diff - \
        (knn_diff * knn_normals).sum(dim=-1, keepdim=True) * knn_normals

    if get_debugging_mode() and points_padded.requires_grad:
        dbg_tensor = get_debugging_tensor()
        if dbg_tensor is None:
            # nothing to record into; do not register a hook that would fail
            logger_py.error("dbg_tensor is None")
        else:
            dbg_lengths = point_clouds.num_points_per_cloud()

            def _save_grad(grad):
                """Store the per-cloud gradient of the points for debugging."""
                dbg = get_debugging_tensor()
                if dbg is None:
                    logger_py.error("dbg_tensor is None")
                    return
                if grad is None:
                    logger_py.error('grad is None')
                    return
                # a dict of list of tensors
                dbg.pts_world_grad['repel'] = [
                    grad[b, :dbg_lengths[b]].detach().cpu()
                    for b in range(grad.shape[0])
                ]

            dbg_tensor.pts_world['repel'] = [
                points_padded[b, :lengths[b]].detach().cpu()
                for b in range(points_padded.shape[0])
            ]
            handle = points_padded.register_hook(_save_grad)
            self.hooks.append(handle)

    with torch.autograd.no_grad():
        spatial_w = self.get_spatial_w(point_clouds, **kwargs)
        # set far neighbors' spatial_w to 0
        normal_w = self.get_normal_w(point_clouds, **kwargs)
        density_w = torch.sum(spatial_w, dim=-1, keepdim=True) + 1.0
        weights = spatial_w * normal_w

    # convert everything to packed
    first_idx = point_clouds.cloud_to_packed_first_idx()
    weights = ops.padded_to_packed(weights, first_idx, P_total)
    # flatten the (K, 3) neighbor axes for packing, then restore them
    pts_diff_proj = ops.padded_to_packed(
        pts_diff_proj.contiguous().view(pts_diff_proj.shape[0],
                                        pts_diff_proj.shape[1], -1),
        first_idx, P_total).view(P_total, -1, 3)
    density_w = ops.padded_to_packed(density_w, first_idx, P_total)

    # we want to maximize this, so negative sign
    repel_vec = torch.sum(pts_diff_proj * weights.unsqueeze(-1),
                          dim=1) / eps_denom(
                              torch.sum(weights, dim=1).unsqueeze(-1))
    repel_vec = repel_vec * density_w

    loss = torch.exp(-repel_vec.abs())
    return loss
def estimate_perpendicular(pc, k, sigma=0.01, clip=0.05):
    """
    Draw, for every point, a small random offset spanned by the two
    eigenvectors with the largest eigenvalues of the covariance of its k
    nearest neighbors (the query point itself is excluded).

    Args:
        pc (tensor): [b, 3, n] point cloud
        k (int): number of neighbors used for the covariance (the
            normalization 1 / (k - 1) requires k > 1)
        sigma (float): standard deviation of the random coefficient applied
            to each of the two directions
        clip (float): each directional contribution is clamped to
            [-clip, clip] before the two are summed
    Returns:
        tensor [b, 3, n], computed without gradient tracking
    """
    with torch.no_grad():
        # pc : [b, 3, n]
        b, _, n = pc.size()
        pc_t = pc.permute(0, 2, 1)
        inter_KNN = knn_points(pc_t, pc_t,
                               K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]
        nn_pts = knn_gather(pc_t, inter_KNN.idx).permute(
            0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n ,k]

        # torch.symeig is gone from recent PyTorch; prefer torch.linalg.eigh
        # and only fall back on installations that predate it
        has_linalg_eigh = hasattr(torch, 'linalg') and hasattr(
            torch.linalg, 'eigh')

        # get covariance matrix and the two dominant eig-vectors of
        # individual point
        perpendi_vector_1 = []
        perpendi_vector_2 = []
        fact = 1.0 / (k - 1)
        for i in range(b):
            curr_point_set = nn_pts[i].detach().permute(
                1, 0, 2)  #curr_point_set:[n, 3, k]
            curr_point_set_mean = torch.mean(
                curr_point_set, dim=2,
                keepdim=True)  #curr_point_set_mean:[n, 3, 1]
            curr_point_set = curr_point_set - curr_point_set_mean  #curr_point_set:[n, 3, k]
            curr_point_set_t = curr_point_set.permute(
                0, 2, 1)  #curr_point_set_t:[n, k, 3]
            cov_mat = fact * torch.bmm(curr_point_set,
                                       curr_point_set_t)  #cov_mat:[n, 3, 3]
            if has_linalg_eigh:
                # UPLO='U' matches the triangle symeig read by default
                eigenvalue, eigenvector = torch.linalg.eigh(cov_mat,
                                                            UPLO='U')
            else:
                eigenvalue, eigenvector = torch.symeig(cov_mat,
                                                       eigenvectors=True)
            # eigenvalue:[n, 3], eigenvector:[n, 3, 3]

            larger_dim_idx = torch.topk(eigenvalue,
                                        2,
                                        dim=1,
                                        largest=True,
                                        sorted=False)[1]  # larger_dim_idx:[n, 2]

            # squeeze only the gathered column axis, so that a cloud with a
            # single point keeps its [n, 3] layout
            persample_perpendi_vector_1 = torch.gather(
                eigenvector, 2,
                larger_dim_idx[:, 0].unsqueeze(1).unsqueeze(2).expand(
                    n, 3, 1)).squeeze(2)  #persample_perpendi_vector_1:[n, 3]
            persample_perpendi_vector_2 = torch.gather(
                eigenvector, 2,
                larger_dim_idx[:, 1].unsqueeze(1).unsqueeze(2).expand(
                    n, 3, 1)).squeeze(2)  #persample_perpendi_vector_2:[n, 3]
            perpendi_vector_1.append(persample_perpendi_vector_1.permute(
                1, 0))
            perpendi_vector_2.append(persample_perpendi_vector_2.permute(
                1, 0))

        perpendi_vector_1 = torch.stack(perpendi_vector_1,
                                        0)  #perpendi_vector_1:[b, 3, n]
        perpendi_vector_2 = torch.stack(perpendi_vector_2,
                                        0)  #perpendi_vector_2:[b, 3, n]

        # drawn on the CPU as before (same random stream), then moved to
        # wherever the input lives instead of unconditionally to CUDA
        aux_vector1 = sigma * torch.randn(b, n).unsqueeze(1).to(
            pc.device)  #aux_vector1:[b, 1, n]
        aux_vector2 = sigma * torch.randn(b, n).unsqueeze(1).to(
            pc.device)  #aux_vector2:[b, 1, n]

    return torch.clamp(perpendi_vector_1 * aux_vector1, -1 * clip,
                       clip) + torch.clamp(perpendi_vector_2 * aux_vector2,
                                           -1 * clip, clip)
class Trainer(BaseTrainer):
    """
    Training / evaluation driver for an implicit surface model rendered
    against images, with optional iso-surface point sampling when the model
    is a ``CombinedModel``.
    """

    def __init__(self,
                 model,
                 optimizer,
                 scheduler,
                 generator,
                 train_loader,
                 val_loader,
                 device='cpu',
                 cameras=None,
                 log_dir=None,
                 vis_dir=None,
                 debug_dir=None,
                 val_dir=None,
                 threshold=0.0,
                 n_training_points=2048,
                 n_eval_points=4000,
                 lambda_occupied=1.,
                 lambda_freespace=1.,
                 lambda_rgb=1.,
                 lambda_eikonal=0.01,
                 patch_size=1,
                 clip_grad=True,
                 reduction_method='sum',
                 sample_continuous=False,
                 overwrite_visualization=True,
                 n_debug_points=-1,
                 saliency_sampling_3d=False,
                 resample_every=-1,
                 refresh_metric_every=-1,
                 gamma_n_points_dss=2.0,
                 gamma_n_rays=0.6,
                 gamma_lambda_rgb=1.0,
                 steps_n_points_dss=-1,
                 steps_n_rays=-1,
                 steps_lambda_rgb=-1,
                 limit_n_points_dss=24000,
                 limit_n_rays=1024,
                 steps_proj_tolerance=-1,
                 gamma_proj_tolerance=0.5,
                 limit_proj_tolerance=5e-5,
                 steps_lambda_sdf=-1,
                 gamma_lambda_sdf=1.0,
                 warm_up_iters=0,
                 sdf_alpha=5.0,
                 limit_sdf_alpha=100,
                 gamma_sdf_alpha=2,
                 steps_sdf_alpha=-1,
                 limit_lambda_freespace=1.0,
                 limit_lambda_occupied=1.0,
                 limit_lambda_rgb=1.0,
                 **kwargs):
        """Initialize the trainer.

        Args:
            model (nn.Module)
            optimizer: optimizer
            scheduler: scheduler
            generator: used to generate meshes and ray-traced images
            train_loader, val_loader: data loaders
            device: device
            log_dir: prefix of the tensorboard log directory; a timestamp
                is appended (must be a string, the default None is not
                usable here)
            **kwargs: stored as ``self.cfg``; 'model_selection_mode' and
                'ref_metric' are read from it by other methods
        The ``steps_*``, ``gamma_*``, ``limit_*`` and ``init`` values are
        passed through to ``TrainerScheduler``.
        """
        self.cfg = kwargs
        self.device = device
        self.model = model
        self.cameras = cameras

        self.val_loader = val_loader
        self.train_loader = train_loader

        self.tb_logger = SummaryWriter(
            log_dir + datetime.datetime.now().strftime("-%Y%m%d-%H%M%S"))

        # implicit function model
        self.vis_dir = vis_dir
        self.val_dir = val_dir
        self.threshold = threshold

        self.lambda_eikonal = lambda_eikonal
        self.lambda_occupied = lambda_occupied
        self.lambda_freespace = lambda_freespace
        self.lambda_rgb = lambda_rgb
        self.sdf_alpha = sdf_alpha

        self.generator = generator
        self.n_eval_points = n_eval_points
        self.patch_size = patch_size
        self.reduction_method = reduction_method
        self.sample_continuous = sample_continuous
        self.overwrite_visualization = overwrite_visualization
        self.saliency_sampling_3d = saliency_sampling_3d
        self.resample_every = resample_every
        self.refresh_metric_every = refresh_metric_every
        self.warm_up_iters = warm_up_iters

        # tuple (score, mesh)
        self._mesh_cache = None
        # maps n_points -> sampled iso-surface point cloud (single entry)
        self._pcl_cache = {}
        self.ref_pcl = None

        n_points_per_cloud = 0
        init_proj_tolerance = 0
        if isinstance(self.model, CombinedModel):
            n_points_per_cloud = self.model.n_points_per_cloud
            init_proj_tolerance = self.model.projection.proj_tolerance

        self.training_scheduler = TrainerScheduler(
            init_n_points_dss=n_points_per_cloud,
            init_n_rays=n_training_points,
            init_proj_tolerance=init_proj_tolerance,
            init_lambda_rgb=lambda_rgb,
            init_lambda_freespace=lambda_freespace,
            init_lambda_occupied=lambda_occupied,
            init_sdf_alpha=self.sdf_alpha,
            steps_n_points_dss=steps_n_points_dss,
            steps_n_rays=steps_n_rays,
            steps_proj_tolerance=steps_proj_tolerance,
            steps_sdf_alpha=steps_sdf_alpha,
            steps_lambda_rgb=steps_lambda_rgb,
            steps_lambda_sdf=steps_lambda_sdf,
            warm_up_iters=self.warm_up_iters,
            gamma_n_points_dss=gamma_n_points_dss,
            gamma_n_rays=gamma_n_rays,
            gamma_proj_tolerance=gamma_proj_tolerance,
            gamma_lambda_rgb=gamma_lambda_rgb,
            gamma_sdf_alpha=gamma_sdf_alpha,
            gamma_lambda_sdf=gamma_lambda_sdf,
            limit_n_points_dss=limit_n_points_dss,
            limit_n_rays=limit_n_rays,
            limit_proj_tolerance=limit_proj_tolerance,
            limit_sdf_alpha=limit_sdf_alpha,
            limit_lambda_rgb=limit_lambda_rgb,
            limit_lambda_occupied=limit_lambda_occupied,
            limit_lambda_freespace=limit_lambda_freespace)

        self.debug_dir = debug_dir
        self.hooks = []

        self.n_training_points = n_training_points
        self.n_debug_points = n_debug_points

        self.optimizer = optimizer
        self.scheduler = scheduler
        self.clip_grad = clip_grad

        self.iou_loss = IouLoss(reduction=self.reduction_method,
                                channel_dim=None)
        self.eikonal_loss = NormalLengthLoss(reduction=self.reduction_method)
        self.l1_loss = L1Loss(reduction=self.reduction_method)
        self.l2_loss = L2Loss(reduction=self.reduction_method)
        self.sdf_loss = SDF2DLoss(reduction=self.reduction_method)

    def _query_mesh(self):
        """
        Generate mesh at the current training (it), evaluate

        Returns the cached mesh, generating and caching it first if needed.
        Returns None when mesh generation fails (the error is logged).

        Raises:
            ValueError: if cfg['model_selection_mode'] is neither
                'maximize' nor 'minimize'
        """
        if self._mesh_cache is None:
            try:
                mesh = self.generator.generate_mesh({},
                                                    with_colors=False,
                                                    with_normals=False)
            except Exception as e:
                return logger_py.error(
                    "Couldn't generate mesh {}".format(e))

            selection_mode = self.cfg['model_selection_mode']
            if selection_mode == 'maximize':
                model_selection_sign = 1
            elif selection_mode == 'minimize':
                model_selection_sign = -1
            else:
                raise ValueError(
                    "model_selection_mode must be 'maximize' or 'minimize', "
                    "got {!r}".format(selection_mode))
            # initialize the score with the worst possible value
            self._mesh_cache = (-model_selection_sign * float('inf'), mesh)
            self._pcl_cache.clear()

        if self._mesh_cache is not None:
            return self._mesh_cache[1]

    def _query_pcl(self, n_points=-1):
        """
        get a uniform point cloud on the iso-surface, save in a cache

        Args:
            n_points: number of points; None or a negative value means
                "reuse the size of the cached point cloud"
        Returns:
            (new_pcl, iso_pcl): whether the cloud was newly sampled, and
            the point cloud
        Raises:
            ValueError: if no size is given and nothing is cached yet
        """
        # None must be tested before the comparison, and the cache lookup
        # only makes sense when the cache is populated
        if n_points is None or n_points < 0:
            if not self._pcl_cache:
                raise ValueError(
                    'n_points must be given when no point cloud is cached')
            n_points = next(iter(self._pcl_cache))

        iso_pcl = self._pcl_cache.get(n_points, None)
        new_pcl = False
        if iso_pcl is None:
            t0 = time.time()
            iso_pcl = sample_uniform_iso_points(
                self.model.decoder,
                n_points,
                bounding_sphere_radius=self.model.object_bounding_sphere,
                init_points=self.model._points.points_padded())
            t1 = time.time()
            normals = self.model.get_normals_from_grad(
                iso_pcl.points_padded(), requires_grad=False)
            logger_py.debug(
                '[Sample from Mesh] time ellapsed {}s'.format(t1 - t0))
            iso_pcl = PointClouds3D(iso_pcl.points_padded(), normals=normals)
            # the cache only ever holds the most recent sampling
            self._pcl_cache.clear()
            self._pcl_cache[n_points] = iso_pcl
            new_pcl = True

        return new_pcl, iso_pcl

    def evaluate_mesh(self, val_dataloader, it, **kwargs):
        """
        Generate a mesh, compare points sampled on it against the target
        point cloud of the validation dataset and export the mesh to
        ``val_dir``.

        Returns:
            dict with key 'chamfer'
        """
        logger_py.info("[Mesh Evaluation]")
        t0 = time.time()
        if not os.path.exists(self.val_dir):
            os.makedirs(self.val_dir)

        mesh_gt = val_dataloader.dataset.get_meshes()
        assert (mesh_gt is not None)
        mesh_gt = mesh_gt.to(device=self.device)

        pointcloud_tgt = val_dataloader.dataset.get_pointclouds(
            num_points=self.n_eval_points)

        mesh = self.generator.generate_mesh({},
                                            with_colors=False,
                                            with_normals=False)
        sampled = trimesh.sample.sample_surface_even(
            mesh,
            pointcloud_tgt.points_packed().shape[0])
        # trimesh returns (points, face_index); only the points are needed
        points_pred = sampled[0] if isinstance(sampled, tuple) else sampled

        # NOTE(review): assumes chamfer_distance returns a single scalar
        # tensor -- confirm against the implementation imported here
        chamfer_dist = chamfer_distance(
            pointcloud_tgt.points_padded(),
            torch.from_numpy(points_pred).view(1, -1, 3).to(
                device=pointcloud_tgt.points_padded().device,
                dtype=torch.float32))
        eval_dict_mesh = {'chamfer': chamfer_dist.item()}

        # save to "val" dict
        t1 = time.time()
        logger_py.info('[Mesh Evaluation] time ellapsed {}s'.format(t1 - t0))
        if not mesh.is_empty:
            mesh.export(os.path.join(self.val_dir, "%010d.ply" % it))
        return eval_dict_mesh

    def eval_step(self, data, **kwargs):
        """
        evaluate with image mask iou or image rgb psnr

        Returns:
            dict with keys 'iou' and 'psnr' (the latter holds the value
            returned by the l2 loss)
        """
        lights_model = kwargs.get('lights',
                                  self.val_loader.dataset.get_lights())
        cameras_model = kwargs.get('cameras',
                                   self.val_loader.dataset.get_cameras())
        img_size = self.generator.img_size
        eval_dict = {'iou': 0.0, 'psnr': 0.0}
        with autograd.no_grad():
            self.model.eval()
            data = self.process_data_dict(data,
                                          cameras_model,
                                          lights=lights_model)
            img_mask = data['mask_img']
            img = data['img']
            # render image
            rgbas = self.generator.raytrace_images(img_size,
                                                   img_mask,
                                                   cameras=data['camera'],
                                                   lights=data['light'])
            assert (len(rgbas) == 1)
            rgba = rgbas[0]
            # add a batch axis and move channels first
            rgba = torch.tensor(rgba[None, ...],
                                dtype=torch.float,
                                device=img_mask.device).permute(0, 3, 1, 2)

            # compare iou
            mask_gt = F.interpolate(img_mask.float(),
                                    img_size,
                                    mode='bilinear',
                                    align_corners=False).squeeze(1)
            mask_pred = rgba[:, 3, :, :]
            eval_dict['iou'] += self.iou_loss(mask_gt.float(),
                                              mask_pred.float(),
                                              reduction='mean')

            # compare psnr
            rgb_gt = F.interpolate(img,
                                   img_size,
                                   mode='bilinear',
                                   align_corners=False)
            rgb_pred = rgba[:, :3, :, :]
            eval_dict['psnr'] += self.l2_loss(rgb_gt,
                                              rgb_pred,
                                              channel_dim=1,
                                              reduction='mean',
                                              align_corners=False).detach()

        return eval_dict

    def train_step(self, data, cameras, **kwargs):
        """
        Args:
            data (dict): contains img, img.mask and img.depth and camera_mat
            cameras (Cameras): Cameras object from pytorch3d
            it (keyword): current training iteration
            lights (keyword, optional): lights passed to the model
        Returns:
            loss
        """
        self.model.train()
        self.optimizer.zero_grad()
        it = kwargs.get("it", None)
        lights = kwargs.get('lights', None)
        if hasattr(self, 'training_scheduler'):
            self.training_scheduler.step(self, it)

        if isinstance(self.model, CombinedModel) and it > self.warm_up_iters:
            # dropping the cache forces a fresh sampling below
            if self.resample_every > 0 and (
                    it - self.warm_up_iters) % self.resample_every == 0:
                self._pcl_cache.clear()
            is_new = self.sample_from_mesh(self.model.n_points_per_cloud)

            if self.saliency_sampling_3d:
                refresh_per_point_metric = is_new or (self.refresh_metric_every > 0) and \
                    ((it - 1) % self.refresh_metric_every == 0)
                self.model.eval()
                if refresh_per_point_metric:
                    self.ref_pcl = self.ref_per_point_metric(
                        mode=self.cfg['ref_metric'])
                    colors = scaler_to_color(
                        self.ref_pcl.features_packed().cpu().numpy().reshape(
                            -1))
                    save_ply(
                        os.path.join(self.vis_dir, '%010d_refpcl.ply' % it),
                        self.ref_pcl.points_packed().cpu().numpy(),
                        colors=colors,
                        normals=self.ref_pcl.normals_packed().cpu().numpy())

        data = self.process_data_dict(data, cameras, lights=lights)
        self.model.train()
        # autograd.set_detect_anomaly(True)
        loss = self.compute_loss(data['img'],
                                 data['mask_img'],
                                 data['input'],
                                 data['camera'],
                                 data['light'],
                                 it=it,
                                 ref_pcl=self.ref_pcl)
        loss.backward()
        if self.clip_grad:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=1.)
        self.optimizer.step()
        check_weights(self.model.state_dict())

        return loss.item()

    def process_data_dict(self, data, cameras, lights=None):
        ''' Processes the data dictionary and returns respective tensors

        Args:
            data (dictionary): data dictionary
            cameras: updated in place with R and T decomposed from
                data['camera_mat'] when that entry is present
            lights: if given and data contains 'lights', replaced by a new
                object of the same type built from those parameters
        Returns:
            dict with keys 'img', 'mask_img', 'input', 'camera', 'light'
        '''
        device = self.device

        # Get "ordinary" data
        img = data.get('img.rgb').to(device)
        assert (img.min() >= 0 and img.max() <= 1
                ), "Image must be a floating number between 0 and 1."
        mask_img = data.get('img.mask').to(device)

        camera_mat = data.get('camera_mat', None)

        # inputs for SVR
        inputs = data.get('inputs', torch.empty(0, 0)).to(device)

        # set camera matrix to cameras
        if camera_mat is None:
            logger_py.warning(
                "Camera matrix is not provided! Using the default matrix")
        else:
            cameras.R, cameras.T = decompose_to_R_and_t(camera_mat)
            cameras._N = cameras.R.shape[0]
            cameras.to(device)

        if lights is not None:
            lights_params = data.get('lights', None)
            if lights_params is not None:
                lights = type(lights)(**lights_params).to(device)

        return {
            'img': img,
            'mask_img': mask_img,
            'input': inputs,
            'camera': cameras,
            'light': lights
        }

    def sample_pixels(self, n_rays: int, batch_size: int, h: int, w: int):
        """
        Return pixel locations on ``self.device``: all pixels of the (h, w)
        image when ``n_rays`` covers the whole image, otherwise ``n_rays``
        sampled patch points.
        """
        if n_rays >= h * w:
            p = arange_pixels((h, w), batch_size)[1].to(self.device)
        else:
            p = sample_patch_points(
                batch_size,
                n_rays,
                patch_size=self.patch_size,
                image_resolution=(h, w),
                continuous=self.sample_continuous,
            ).to(self.device)
        return p

    def sample_from_mesh(self, n_points):
        """
        Construct mesh from implicit model and sample from the mesh to get
        iso-surface points, which is used for projection in the combined
        model
        Return True is a new pcl is sampled
        (False as well when sampling fails; the error is logged)
        """
        try:
            new_pcl, pcl = self._query_pcl(n_points)
            self.model._points = pcl
            self.model.points = pcl.points_padded()
            if not os.path.exists(self.vis_dir):
                os.makedirs(self.vis_dir)
        except Exception as e:
            logger_py.error("Couldn't sample points from mesh: {}".format(e))
            return False

        return new_pcl

    def compute_loss(self,
                     img,
                     mask_img,
                     inputs,
                     cameras,
                     lights,
                     n_points=None,
                     eval_mode=False,
                     it=None,
                     ref_pcl=None):
        ''' Compute the loss.

        Args:
            img, mask_img: image and mask batches; their spatial sizes
                must agree
            inputs: conditioning inputs forwarded to the model
            cameras, lights: forwarded to the model
            n_points: number of sampled pixels; defaults to n_eval_points
                in eval mode and n_training_points otherwise
            eval_mode (bool): whether to use eval mode
            it (int): training iteration
            ref_pcl: reference point cloud, only passed on when
                saliency_sampling_3d is enabled
        Returns:
            the loss dictionary in eval mode, otherwise loss['loss']
        '''
        # Initialize loss dictionary and other values
        loss = {}

        # overwrite n_points
        if n_points is None:
            n_points = self.n_eval_points if eval_mode else self.n_training_points

        # Shortcuts
        device = self.device
        patch_size = self.patch_size
        reduction_method = self.reduction_method
        batch_size, _, h, w = img.shape

        # Assertions
        assert (((h, w) == mask_img.shape[2:4]) and (patch_size > 0))

        # Apply losses
        # 1.) Initialize loss
        loss['loss'] = 0

        if isinstance(self.model, CombinedModel):
            # 1.) Sample points on image plane ("pixels")
            p = None
            if n_points > 0:
                p = self.sample_pixels(n_points, batch_size, h, w)

            # projection is only used after warm-up and never in eval mode
            project = (it - self.warm_up_iters > 0) and not eval_mode
            sample_iso_offsurface = project
            # TODO: check insertion fix: start using insertion after 5000 iterations
            saliency_sampling_3d = self.saliency_sampling_3d
            model_outputs = self.model(
                mask_img,
                img,
                cameras,
                pixels=p,
                inputs=inputs,
                lights=lights,
                it=it,
                eval_mode=eval_mode,
                project=project,
                sample_iso_offsurface=sample_iso_offsurface,
                proj_kwargs={
                    'ref_pcl': ref_pcl if saliency_sampling_3d else None,
                    'insert': saliency_sampling_3d
                })
        else:
            # 1.) Sample points on image plane ("pixels")
            p = self.sample_pixels(n_points, batch_size, h, w)

            model_outputs = self.model(p,
                                       img,
                                       mask_img,
                                       inputs=inputs,
                                       cameras=cameras,
                                       lights=lights,
                                       it=it)

        point_clouds = model_outputs.get('iso_pcl')
        rgb_gt = model_outputs.get('iso_rgb_gt')
        sdf_freespace = model_outputs.get('sdf_freespace')
        sdf_occupancy = model_outputs.get('sdf_occupancy')

        if it % 50 == 0 and not eval_mode and not get_debugging_mode():
            logger_py.debug('# iso: {}, # occ_off: {}, # free_off: {}'.format(
                model_outputs['iso_rgb_gt'].shape[0],
                model_outputs['p_occupancy'].shape[0],
                model_outputs['p_freespace'].shape[0]))

        if not point_clouds.isempty():
            # Photo Consistency Loss
            normalizing_value = 1.0
            if reduction_method == 'sum':
                total_p = batch_size * self.training_scheduler.init_n_rays
                normalizing_value = 1.0 / total_p * min(
                    (sdf_freespace.nelement() + sdf_occupancy.nelement()) /
                    float(rgb_gt.size(0)), 1.0)
            self.calc_photoconsistency_loss(
                point_clouds,
                rgb_gt,
                reduction_method,
                loss,
                normalizing_value=normalizing_value)

        # Occupancy and Freespace losses
        if self.lambda_occupied > 0 or self.lambda_freespace > 0:
            normalizing_value = 1.0
            if reduction_method == 'sum':
                total_p = batch_size * self.training_scheduler.init_n_rays
                normalizing_value = 1.0 / total_p
            self.calc_sdf_mask_loss(sdf_freespace,
                                    sdf_occupancy,
                                    self.sdf_alpha,
                                    reduction_method,
                                    loss,
                                    normalizing_value=normalizing_value)

        # Eikonal loss
        # Random samples in the space
        total_p = batch_size * self.training_scheduler.init_n_rays
        space_pts = torch.empty(total_p, 3).uniform_(
            -self.model.object_bounding_sphere,
            self.model.object_bounding_sphere).to(device=device)

        eikonal_normals = self.model.get_normals_from_grad(
            space_pts, c=inputs, requires_grad=True)
        normalizing_value = 1.0
        if reduction_method == 'sum':
            normalizing_value = 1.0 / total_p
        self.calc_eikonal_loss(eikonal_normals,
                               reduction_method,
                               loss,
                               normalizing_value=normalizing_value)

        # log every loss term to tensorboard
        mode = 'val' if eval_mode else 'train'
        for k, v in loss.items():
            if isinstance(v, torch.Tensor):
                self.tb_logger.add_scalar('%s/%s' % (mode, k), v.item(), it)
            else:
                self.tb_logger.add_scalar('%s/%s' % (mode, k), v, it)

        return loss if eval_mode else loss['loss']

    def ref_per_point_metric(self,
                             ref_pcl: PointClouds3D = None,
                             mode='curvature'):
        """
        Computes the metric used for sampling or weighting, e.g. curvature /
        loss, and update to pcl's features
        When mode == 'curvature', estimates the shape curvature from the
        point cloud using a local neighborhood of 12 points,
        When mode == 'loss', average the per point per view RGB loss over
        the entire training images

        Args:
            ref_pcl: reference point cloud; defaults to the model's points.
                Clouds with more than 5000 points are subsampled first.
            mode: 'curvature' or 'loss'
        Returns:
            the reference point cloud with the metric stored as features
        Raises:
            ValueError: for an unknown ``mode``
        """
        if mode not in ('loss', 'curvature'):
            raise ValueError(
                "mode must be 'loss' or 'curvature', got {!r}".format(mode))

        if ref_pcl is None:
            ref_pcl = self.model._points

        if (n_points := ref_pcl.points_packed().shape[0]) > 5000:
            ref_pcl = farthest_sampling(ref_pcl, 5000 / n_points)

        with autograd.no_grad():
            if mode == 'loss':
                self.model.eval()
                cameras = self.val_loader.dataset.get_cameras()
                lights = self.val_loader.dataset.get_lights()

                # project all init_points to the surface
                proj_result = self.model.projection.project_points(
                    ref_pcl,
                    self.model.decoder,
                    skip_resampling=False,
                    skip_upsampling=False,
                    sample_iters=2)
                ref_pcl = PointClouds3D(
                    proj_result['levelset_points'][proj_result['mask']].view(
                        1, -1, 3),
                    normals=proj_result['levelset_normals'][
                        proj_result['mask']].view(1, -1, 3))

                num_points2 = ref_pcl.num_points_per_cloud()
                logger_py.info(
                    '[Per Point Loss Metric] evaluating ref point cloud ({}) on all training images'
                    .format(num_points2.item()))
                assert (len(num_points2) == 1)
                from ..utils.mathHelper import RunningStat
                runStat = RunningStat(num_points2.item(),
                                      device=ref_pcl.device)

                # set max_iso_per_batch to -1
                max_iso_per_batch = self.model.max_iso_per_batch
                self.model.max_iso_per_batch = -1
                try:
                    # If loss is used, transfer the computed loss in the
                    # current point cloud to the reference point cloud
                    for batch in tqdm(self.val_loader):
                        data = self.process_data_dict(batch,
                                                      cameras=cameras,
                                                      lights=lights)
                        mask_img, img, cameras, lights = data[
                            'mask_img'], data['img'], data['camera'], data[
                                'light']
                        # set proj_max_iters to 0 because we already
                        # projected the points before hand
                        model_outputs = self.model(
                            mask_img,
                            img,
                            cameras,
                            mask_gt=None,
                            pixels=None,
                            inputs=None,
                            lights=lights,
                            project=True,
                            sample_iso_offsurface=False,
                        )
                        point_clouds = model_outputs['iso_pcl']
                        pixel_pred = model_outputs['iso_pixel']
                        rgb_gt = model_outputs['iso_rgb_gt']
                        sig = num_points2.item(
                        ) / self.model.object_bounding_sphere
                        dist_thres = 4 / sig
                        if not point_clouds.isempty():
                            num_points1 = point_clouds.num_points_per_cloud(
                            ).sum(0, keepdim=True)
                            loss = {'loss': 0.0}
                            self.calc_photoconsistency_loss(
                                point_clouds,
                                rgb_gt,
                                'none',
                                loss,
                            )
                            loss_per_point = loss['loss_rgb'] / self.lambda_rgb
                            query_points = point_clouds.points_packed()
                            # nearest visible point for every ref point
                            dists, idxs, _ = knn_points(
                                ref_pcl.points_padded().contiguous(),
                                query_points.unsqueeze(0).contiguous(),
                                num_points2,
                                num_points1,
                                K=1,
                                return_nn=False)
                            loss_ref_point = knn_gather(
                                loss_per_point.view(1, num_points1.item(), 1),
                                idxs, num_points1).view(-1, 1)
                            # only accumulate ref points that have a close
                            # (but not coincident) match
                            mask = (dists < dist_thres) & (dists > 0)
                            runStat.add(loss_ref_point.view(-1),
                                        mask.view(-1))
                finally:
                    # restore the setting even if an evaluation step fails
                    self.model.max_iso_per_batch = max_iso_per_batch

                per_point_metric = runStat.mean().view(-1, 1)
                logger_py.debug(
                    '[Per Point Loss Metric] ref point cloud metric min {} max {} median {}'
                    .format(per_point_metric.min(), per_point_metric.max(),
                            per_point_metric.median()))

            # or curvature is used
            else:
                curvatures, _ = estimate_pointcloud_local_coord_frames(
                    ref_pcl,
                    neighborhood_size=12,
                    disambiguate_directions=False,
                    return_knn_result=False)
                # high curvature area : variance in the local frame is large
                curvatures = curvatures[..., 0] / \
                    eps_denom(curvatures[..., -1])
                per_point_metric = curvatures.view(-1, 1)

            ref_pcl = ref_pcl.update_features_(per_point_metric)

        return ref_pcl
def _compute_sampling_metrics(pred_points, pred_normals, gt_points, gt_normals,
                              thresholds, eps):
    """
    Evaluate sampling-based metrics between predicted and ground-truth
    point samples:
    - L2 Chamfer distance
    - Precision at various thresholds
    - Recall at various thresholds
    - F1 score at various thresholds
    - Normal consistency (if normals are provided)
    - Absolute normal consistency (if normals are provided)

    Inputs:
        - pred_points: Tensor of shape (N, S, 3) giving coordinates of sampled
          points for each predicted mesh
        - pred_normals: Tensor of shape (N, S, 3) giving normals of points
          sampled from the predicted mesh, or None if such normals are not
          available
        - gt_points: Tensor of shape (N, S, 3) giving coordinates of sampled
          points for each ground-truth mesh
        - gt_normals: Tensor of shape (N, S, 3) giving normals of points
          sampled from the ground-truth verts, or None if such normals are
          not available
        - thresholds: Distance thresholds to use for precision / recall / F1
        - eps: epsilon added to the F1 denominator for numerical stability

    Returns:
        - metrics: A dictionary mapping metric names to CPU Tensors of shape
          (N,) holding the value of the metric for each batch element
    """

    def _full_lengths(points):
        # Every cloud in the batch is treated as holding all S samples.
        return torch.full((points.shape[0], ),
                          points.shape[1],
                          dtype=torch.int64,
                          device=points.device)

    def _nearest(src_pts, dst_pts, src_len, dst_len, dst_normals):
        # One nearest neighbour in dst for every point of src.
        match = knn_points(src_pts,
                           dst_pts,
                           lengths1=src_len,
                           lengths2=dst_len,
                           K=1)
        sq_dists = match.dists[..., 0]  # (N, S) values reported by knn_points
        plain_dists = sq_dists.sqrt()  # (N, S) their square roots
        if dst_normals is None:
            matched_normals = None
        else:
            # Normal of the matched dst point, one per src point: (N, S, 3)
            matched_normals = knn_gather(dst_normals, match.idx,
                                         dst_len)[..., 0, :]
        return sq_dists, plain_dists, matched_normals

    n_pred = _full_lengths(pred_points)
    n_gt = _full_lengths(gt_points)

    # Prediction -> ground truth, then ground truth -> prediction.
    p2g_sq, p2g, gt_normals_at_pred = _nearest(pred_points, gt_points, n_pred,
                                               n_gt, gt_normals)
    g2p_sq, g2p, pred_normals_at_gt = _nearest(gt_points, pred_points, n_gt,
                                               n_pred, pred_normals)

    # Symmetric Chamfer distance built from the knn_points distance values.
    metrics = {"Chamfer-L2": p2g_sq.mean(dim=1) + g2p_sq.mean(dim=1)}

    # Normal-based metrics need normals on both sides.
    have_both_normals = not (pred_normals is None or gt_normals is None)
    if have_both_normals:
        cos_p2g = F.cosine_similarity(pred_normals, gt_normals_at_pred, dim=2)
        cos_g2p = F.cosine_similarity(gt_normals, pred_normals_at_gt, dim=2)

        signed_p2g = cos_p2g.mean(dim=1)
        unsigned_p2g = cos_p2g.abs().mean(dim=1)
        signed_g2p = cos_g2p.mean(dim=1)
        unsigned_g2p = cos_g2p.abs().mean(dim=1)

        metrics["NormalConsistency"] = 0.5 * (signed_p2g + signed_g2p)
        metrics["AbsNormalConsistency"] = 0.5 * (unsigned_p2g + unsigned_g2p)

    # Percentage of points whose nearest neighbour lies within each threshold.
    for t in thresholds:
        precision = 100.0 * (p2g < t).float().mean(dim=1)
        recall = 100.0 * (g2p < t).float().mean(dim=1)
        metrics["Precision@%f" % t] = precision
        metrics["Recall@%f" % t] = recall
        metrics["F1@%f" % t] = (2.0 * precision * recall) / (precision +
                                                             recall + eps)

    # Hand every metric back on the CPU.
    return {name: value.cpu() for name, value in metrics.items()}