Example No. 1
def _get_kappa_adv(adv_pc, ori_pc, ori_normal, k=2):
    b, _, n = adv_pc.size()
    # compute knn between advPC and oriPC to get normal n_p
    #intra_dis = ((adv_pc.unsqueeze(3) - ori_pc.unsqueeze(2))**2).sum(1)
    #intra_idx = torch.topk(intra_dis, 1, dim=2, largest=False, sorted=True)[1]
    #normal = torch.gather(ori_normal, 2, intra_idx.view(b,1,n).expand(b,3,n))
    intra_KNN = knn_points(adv_pc.permute(0, 2, 1),
                           ori_pc.permute(0, 2, 1),
                           K=1)  #[dists:[b,n,1], idx:[b,n,1]]
    normal = knn_gather(ori_normal.permute(0, 2, 1), intra_KNN.idx).permute(
        0, 3, 1, 2).squeeze(3).contiguous()  # [b, 3, n]

    # compute knn between advPC and itself to get \|q-p\|_2
    #inter_dis = ((adv_pc.unsqueeze(3) - adv_pc.unsqueeze(2))**2).sum(1)
    #inter_idx = torch.topk(inter_dis, k+1, dim=2, largest=False, sorted=True)[1][:, :, 1:].contiguous()
    #nn_pts = torch.gather(adv_pc, 2, inter_idx.view(b,1,n*k).expand(b,3,n*k)).view(b,3,n,k)
    inter_KNN = knn_points(adv_pc.permute(0, 2, 1),
                           adv_pc.permute(0, 2, 1),
                           K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]
    nn_pts = knn_gather(adv_pc.permute(0, 2, 1), inter_KNN.idx).permute(
        0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n, k]
    vectors = nn_pts - adv_pc.unsqueeze(3)
    vectors = _normalize(vectors)

    return torch.abs(
        (vectors *
         normal.unsqueeze(3)).sum(1)).mean(2), normal  # [b, n], [b, 3, n]
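A minimal usage sketch (not from the original source; `_normalize` is stubbed here as plain L2 normalization along the channel dimension, matching how it is used above):

import torch
import torch.nn.functional as F
from pytorch3d.ops import knn_points, knn_gather

def _normalize(v, dim=1):
    # stand-in for the helper assumed above: unit-length vectors along `dim`
    return F.normalize(v, dim=dim)

b, n = 2, 1024
adv_pc = torch.rand(b, 3, n)
ori_pc = adv_pc + 0.01 * torch.randn(b, 3, n)
ori_normal = F.normalize(torch.randn(b, 3, n), dim=1)
kappa, normal = _get_kappa_adv(adv_pc, ori_pc, ori_normal, k=2)
print(kappa.shape, normal.shape)  # [2, 1024] and [2, 3, 1024]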
Example No. 2
    def get_normal_w(self,
                     point_clouds: PointClouds3D,
                     normals: Optional[torch.Tensor] = None,
                     **kwargs):
        """
        Weights exp(-\|n-ni\|^2/sharpness_sigma^2), for i in a local neighborhood
        Args:
            point_clouds: whose normals will be used for ni
            normals (tensor): (N, maxP, 3) padded normals as n, if not provided, use
                the normals from point_clouds
        Returns:
            weight per point per neighbor (N,maxP,K)
        """
        self.sharpness_sigma = kwargs.get('sharpness_sigma',
                                          self.sharpness_sigma)
        inv_sigma_normal = 1 / (self.sharpness_sigma * self.sharpness_sigma)
        lengths = point_clouds.num_points_per_cloud()

        if normals is None:
            normals = point_clouds.normals_padded()
        knn_normals = ops.knn_gather(normals, self.knn_tree.idx, lengths)
        normals = torch.nn.functional.normalize(normals, dim=-1)
        knn_normals = torch.nn.functional.normalize(knn_normals, dim=-1)
        w = knn_normals - normals[:, :, None, :]

        w = torch.exp(-torch.sum(w * w, dim=-1) * inv_sigma_normal)
        return w
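The weight is a Gaussian on the difference of unit normals. A self-contained sketch of the same formula on plain tensors (the sigma value is an assumed placeholder):

import torch
import torch.nn.functional as F

N, P, K = 2, 16, 4
sigma = 0.5  # assumed sharpness_sigma
n = F.normalize(torch.randn(N, P, 3), dim=-1)         # center normals
n_knn = F.normalize(torch.randn(N, P, K, 3), dim=-1)  # neighbor normals
d = n_knn - n[:, :, None, :]
w = torch.exp(-torch.sum(d * d, dim=-1) / (sigma * sigma))  # (N, P, K), in (0, 1]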
Example No. 3
def _get_kappa_ori(pc, normal, k=2):
    b, _, n = pc.size()
    #inter_dis = ((pc.unsqueeze(3) - pc.unsqueeze(2))**2).sum(1)
    #inter_idx = torch.topk(inter_dis, k+1, dim=2, largest=False, sorted=True)[1][:, :, 1:].contiguous()
    #nn_pts = torch.gather(pc, 2, inter_idx.view(b,1,n*k).expand(b,3,n*k)).view(b,3,n,k)
    inter_KNN = knn_points(pc.permute(0, 2, 1), pc.permute(0, 2, 1),
                           K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]
    nn_pts = knn_gather(pc.permute(0, 2, 1), inter_KNN.idx).permute(
        0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n, k]
    vectors = nn_pts - pc.unsqueeze(3)
    vectors = _normalize(vectors)

    return torch.abs((vectors * normal.unsqueeze(3)).sum(1)).mean(2)  # [b, n]
Example No. 4
    def _denoise_normals(self,
                         point_clouds,
                         weights,
                         point_clouds_filter=None):
        """
        robust normal mollification (Sec 4.4), i.e. replace normals with a weighted average
        from neighboring normals
        do this only for invisible points (?)
        Args:
            weights (tensors): (N,max_P,K)
        """
        lengths = point_clouds.num_points_per_cloud()
        P_total = lengths.sum().item()
        normals = point_clouds.normals_padded()
        batch_size, max_P, _ = normals.shape

        knn_normals = ops.knn_gather(normals, self.knn_tree.idx, lengths)
        normals_denoised = torch.sum(knn_normals * weights[:, :, :, None], dim=-2) / \
            eps_denom(torch.sum(weights, dim=-1, keepdim=True))

        # get point visibility so that we update only the non-visible or out-of-mask normals
        if point_clouds_filter is not None:
            try:
                reliable_normals_mask = point_clouds_filter.visibility & point_clouds_filter.inmask
                if len(point_clouds) != reliable_normals_mask.shape[0]:
                    if len(point_clouds) == 1 and reliable_normals_mask.shape[0] > 1:
                        reliable_normals_mask = reliable_normals_mask.any(
                            dim=0, keepdim=True)
                    else:
                        raise ValueError(
                            "Incompatible point clouds {} and mask {}".format(
                                len(point_clouds),
                                reliable_normals_mask.shape))

                # found visibility 0/1 as the last dimension of the features
                # reset visible points normals to its original ones
                normals_denoised[reliable_normals_mask == 1] = normals[
                    reliable_normals_mask == 1]
            except KeyError as e:
                pass

        normals_packed = point_clouds.normals_packed()
        normals_denoised_packed = ops.padded_to_packed(
            normals_denoised, point_clouds.cloud_to_packed_first_idx(),
            P_total)
        point_clouds.update_normals_(normals_denoised_packed)
        return point_clouds
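The core of the mollification is a weighted average over neighbor normals. A self-contained sketch on plain tensors (`eps_denom` is stubbed as a clamp, an assumption about the helper used above):

import torch
import torch.nn.functional as F
from pytorch3d.ops import knn_points, knn_gather

def eps_denom(x, eps=1e-17):
    # stand-in for the helper assumed above: guard against division by zero
    return x.clamp(min=eps)

N, P, K = 2, 128, 8
pts = torch.rand(N, P, 3)
normals = F.normalize(torch.randn(N, P, 3), dim=-1)
knn = knn_points(pts, pts, K=K)
knn_normals = knn_gather(normals, knn.idx)  # (N, P, K, 3)
weights = torch.rand(N, P, K)               # e.g. phi * normal_w
normals_denoised = torch.sum(knn_normals * weights[:, :, :, None], dim=-2) / \
    eps_denom(torch.sum(weights, dim=-1, keepdim=True))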
Example No. 5
def knn_group_withI(points1, points2, intensity2, k):
    '''
    Input:
        points1: [B,3,N]
        points2: [B,3,N]
        intensity2: [B,1,N]
    '''
    points1 = points1.permute(0,2,1).contiguous()
    points2 = points2.permute(0,2,1).contiguous()
    _, nn_idx, nn = knn_points(points1, points2, K=k, return_nn=True)
    points_resi = nn - points1.unsqueeze(2).repeat(1,1,k,1) # [B,N,k,3]
    grouped_dist = torch.norm(points_resi, dim=-1, keepdim=True)
    grouped_features = knn_gather(intensity2.permute(0,2,1), nn_idx) # [B,M,k,1]
    new_features = torch.cat([points_resi, grouped_dist], dim=-1)
    
    # [B,4,N,k], [B,3,N,k], [B,1,N,k]
    return new_features.permute(0,3,1,2).contiguous(), \
        nn.permute(0,3,1,2).contiguous(), \
        grouped_features.permute(0,3,1,2).contiguous()
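A quick shape check for the function above (a sketch; assumes knn_points and knn_gather from pytorch3d.ops are imported):

import torch

B, N, k = 2, 512, 16
p1, p2 = torch.rand(B, 3, N), torch.rand(B, 3, N)
inten = torch.rand(B, 1, N)
feats, nn_xyz, nn_inten = knn_group_withI(p1, p2, inten, k)
print(feats.shape)     # torch.Size([2, 4, 512, 16])
print(nn_xyz.shape)    # torch.Size([2, 3, 512, 16])
print(nn_inten.shape)  # torch.Size([2, 1, 512, 16])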
Example No. 6
def estimate_normal_via_ori_normal(pc_adv, pc_ori, normal_ori, k):
    # pc_adv, pc_ori, normal_ori : [b,3,n]
    b, _, n = pc_adv.size()
    intra_KNN = knn_points(pc_adv.permute(0, 2, 1),
                           pc_ori.permute(0, 2, 1),
                           K=k)  #[dists:[b,n,k], idx:[b,n,k]]
    inter_value = intra_KNN.dists[:, :, 0].contiguous()
    inter_idx = intra_KNN.idx.permute(0, 2, 1).contiguous()
    normal_pts = knn_gather(normal_ori.permute(0, 2, 1),
                            intra_KNN.idx).permute(0, 3, 1, 2).contiguous()  # [b, 3, n, k]

    normal_pts_avg = normal_pts.mean(dim=-1)
    normal_pts_avg = normal_pts_avg / (normal_pts_avg.norm(dim=1, keepdim=True) + 1e-12)

    # If the points are not modified (distance = 0), use the normal directly from the original
    # one. Otherwise, use the mean of the normals of the k-nearest points.
    normal_ori_select = normal_pts[:, :, :, 0]
    condition = (inter_value < 1e-6).unsqueeze(1).expand_as(normal_ori_select)
    normals_estimated = torch.where(condition, normal_ori_select,
                                    normal_pts_avg)

    return normals_estimated
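A small sketch of the selection behavior: points left untouched get their original normal back verbatim, while perturbed points receive the k-neighbor average (assumes the pytorch3d imports used above):

import torch
import torch.nn.functional as F

b, n, k = 2, 256, 4
pc_ori = torch.rand(b, 3, n)
normal_ori = F.normalize(torch.randn(b, 3, n), dim=1)
pc_adv = pc_ori.clone()
pc_adv[:, :, n // 2:] += 0.05 * torch.randn(b, 3, n - n // 2)  # perturb half
est = estimate_normal_via_ori_normal(pc_adv, pc_ori, normal_ori, k)
# untouched points keep their original normals exactly
assert torch.allclose(est[:, :, :n // 2], normal_ori[:, :, :n // 2])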
Example No. 7
    def knn_group(self, points1, points2, features2, k):
        '''
        For each point in points1, query kNN points/features in points2/features2
        Input:
            points1: [B,3,N]
            points2: [B,3,N]
            features2: [B,C,N]
        Output:
            new_features: [B,4,N,k]
            nn: [B,3,N,k]
            grouped_features: [B,C,N,k]
        '''
        points1 = points1.permute(0,2,1).contiguous()
        points2 = points2.permute(0,2,1).contiguous()
        _, nn_idx, nn = knn_points(points1, points2, K=k, return_nn=True)
        points_resi = nn - points1.unsqueeze(2).repeat(1,1,k,1)
        grouped_dist = torch.norm(points_resi, dim=-1, keepdim=True)
        grouped_features = knn_gather(features2.permute(0,2,1), nn_idx)
        new_features = torch.cat([points_resi, grouped_dist], dim=-1)

        return new_features.permute(0,3,1,2).contiguous(),\
            nn.permute(0,3,1,2).contiguous(),\
            grouped_features.permute(0,3,1,2).contiguous()
Example No. 8
    def forward(self, depth_img_dict_1, depth_img_dict_2, cam_info, R12, t12, R21, t21):

        ### format them into point clouds 
        pcl_gt_1, pcl_gt_1_in_2 = self.load_pcl(cam_info, depth_img_dict_1, R21, t21)
        pcl_gt_2, pcl_gt_2_in_1 = self.load_pcl(cam_info, depth_img_dict_2, R12, t12)


        ### knn
        dists_1, idxs_1, pcl_knn_to_1 = knn_pcl(pcl_gt_1, pcl_gt_2_in_1, self.num_nn) #(N, P1, D), (N, P2, D) -> (N, P1, K), (N, P1, K), (N, P1, K, D)
        dists_2, idxs_2, pcl_knn_to_2 = knn_pcl(pcl_gt_2, pcl_gt_1_in_2, self.num_nn)

        ### distance kernel 
        ell = self.ell_min + self.ell_rand
        length_scale_1 = ell * (pcl_gt_1.points_padded()[:,:, [2]]-self.ell_basedist) / self.ell_basedist    # points_padded is (N, P, D), length_scale is (N,P,1)
        length_scale_1 = length_scale_1.clamp(min=ell).pow(2)
        length_scale_2 = ell * (pcl_gt_2.points_padded()[:,:, [2]]-self.ell_basedist) / self.ell_basedist    # points_padded is (N, P, D), length_scale is (N,P,1)
        length_scale_2 = length_scale_2.clamp(min=ell).pow(2)

        dist_kernel_1 = torch.exp( - dists_1 / length_scale_1 )
        dist_kernel_2 = torch.exp( - dists_2 / length_scale_2 )

        ### color kernel
        color_scale = 0.2
        pcl_hsv_knn_to_1 = knn_gather(pcl_gt_2_in_1.features_padded()[:, :, :3], idxs_1, pcl_gt_2_in_1.num_points_per_cloud() )   # (N, P1, K, D)
        pcl_hsv_knn_to_2 = knn_gather(pcl_gt_1_in_2.features_padded()[:, :, :3], idxs_2, pcl_gt_1_in_2.num_points_per_cloud() )

        color_dist_1 = (pcl_gt_1.features_padded()[:, :, :3].unsqueeze(2) - pcl_hsv_knn_to_1).norm(dim=-1)  # (N, P1, K)
        color_dist_2 = (pcl_gt_2.features_padded()[:, :, :3].unsqueeze(2) - pcl_hsv_knn_to_2).norm(dim=-1)  # (N, P2, K)

        color_kernel_1 = torch.exp( - color_dist_1 / color_scale )
        color_kernel_2 = torch.exp( - color_dist_2 / color_scale )

        ### calculate normal
        normal_gt_1 = pcl_gt_1.normals_padded()
        normal_gt_2 = pcl_gt_2.normals_padded()
        normal_gt_1_in_2 = pcl_gt_1_in_2.normals_padded()
        normal_gt_2_in_1 = pcl_gt_2_in_1.normals_padded()
        
        ### nres
        nres_gt_1 = pcl_gt_1.features_padded()[:,:,[3]]
        nres_gt_2 = pcl_gt_2.features_padded()[:,:,[3]]
        nres_gt_1_in_2 = pcl_gt_1_in_2.features_padded()[:,:,[3]]
        nres_gt_2_in_1 = pcl_gt_2_in_1.features_padded()[:,:,[3]]

        normal_knn_to_1 = knn_gather(normal_gt_2_in_1, idxs_1, pcl_gt_2_in_1.num_points_per_cloud() )   # (N, P1, K, D)
        normal_knn_to_2 = knn_gather(normal_gt_1_in_2, idxs_2, pcl_gt_1_in_2.num_points_per_cloud() )   # (N, P1, K, D)

        nres_knn_to_1 = knn_gather(nres_gt_2_in_1, idxs_1, pcl_gt_2_in_1.num_points_per_cloud() )   # (N, P1, K, D)
        nres_knn_to_2 = knn_gather(nres_gt_1_in_2, idxs_2, pcl_gt_1_in_2.num_points_per_cloud() )   # (N, P1, K, D)
        
        alpha_1 = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_gt_1.unsqueeze(2) + nres_knn_to_1 ).squeeze(-1)
        alpha_2 = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_gt_2.unsqueeze(2) + nres_knn_to_2 ).squeeze(-1)

        normal_kernel_1 = ((normal_gt_1.unsqueeze(2) * normal_knn_to_1).sum(-1)*alpha_1).clamp(min=0)
        normal_kernel_2 = ((normal_gt_2.unsqueeze(2) * normal_knn_to_2).sum(-1)*alpha_2).clamp(min=0)
        
        ### final kernel
        mask_1 = mask_from_pcl(pcl_gt_1)
        mask_2 = mask_from_pcl(pcl_gt_2)

        kernel_1 = (dist_kernel_1 * color_kernel_1 * normal_kernel_1 * mask_1).sum() / (pcl_gt_1.num_points_per_cloud().sum()*self.num_nn)
        kernel_2 = (dist_kernel_2 * color_kernel_2 * normal_kernel_2 * mask_2).sum() / (pcl_gt_2.num_points_per_cloud().sum()*self.num_nn)

        if self.log_loss:
            inp = (kernel_1.log() + kernel_2.log()) / 2
        else:
            inp = (kernel_1 + kernel_2) / 2

        return inp
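The distance kernel uses a depth-dependent length scale, exp(-d^2 / ell(z)^2), where the dists returned by knn_points are already squared. A toy sketch of that scaling (hyperparameter values are assumptions):

import torch

ell_min, ell_rand, ell_basedist = 0.05, 0.0, 10.0  # assumed hyperparameters
ell = ell_min + ell_rand
z = torch.tensor([[[5.0], [10.0], [40.0]]])        # (1, 3, 1) point depths
length_scale = (ell * (z - ell_basedist) / ell_basedist).clamp(min=ell).pow(2)
dists = torch.full((1, 3, 4), 0.01)                # squared knn distances
dist_kernel = torch.exp(-dists / length_scale)     # (1, 3, 4); far points get a wider kernel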
Example No. 9
    def forward(self, depth_img_dict_1=None, depth_img_dict_2=None, flow_dict_1to2=None, flow_dict_2to1=None, cam_info=None):

        ### format them into point clouds 
        pcl_pred_1, pcl_gt_1, pcl_flowed_2_from_1 = self.load_pcl(cam_info, depth_img_dict_1, flow_dict_1to2)
        pcl_pred_2, pcl_gt_2, pcl_flowed_1_from_2 = self.load_pcl(cam_info, depth_img_dict_2, flow_dict_2to1)

        # logging.info("n gt: {}".format(pcl_gt_1.num_points_per_cloud()))
        # logging.info("n pred: {}".format(pcl_pred_1.num_points_per_cloud()))
        # logging.info("n flow: {}".format(pcl_flowed_1_from_2.num_points_per_cloud()))
        
        # pclpad_pred_1 = pcl_pred_1.points_padded()
        # pclpad_pred_2 = pcl_pred_2.points_padded()
        # pclpad_gt_1 = pcl_gt_1.points_padded()
        # pclpad_gt_2 = pcl_gt_2.points_padded()
        # pclpad_flowed_2_from_1 = pcl_flowed_2_from_1.points_padded()
        # pclpad_flowed_1_from_2 = pcl_flowed_1_from_2.points_padded()

        ### knn
        dists_1, idxs_1, pcl_knn_to_1 = knn_pcl(pcl_gt_1, pcl_pred_1, self.num_nn) #(N, P1, D), (N, P2, D) -> (N, P1, K), (N, P1, K), (N, P1, K, D)
        dists_2, idxs_2, pcl_knn_to_2 = knn_pcl(pcl_gt_2, pcl_pred_2, self.num_nn)
        dists_1_from_2, idxs_1_from_2, pcl_knn_flowed_to_1 = knn_pcl(pcl_gt_1, pcl_flowed_1_from_2, self.num_nn)
        dists_2_from_1, idxs_2_from_1, pcl_knn_flowed_to_2 = knn_pcl(pcl_gt_2, pcl_flowed_2_from_1, self.num_nn)

        # pcl_knn_to_1 = pcl_from_knnidx(pcl_pred_1, idxs_1)
        # pcl_knn_to_2 = pcl_from_knnidx(pcl_pred_2, idxs_2)
        # pytorch3d.ops.knn_points(pclpad_gt_1, pclpad_pred_1, num_points_per_cloud, lengths2: Optional[torch.Tensor] = None, K: int = 1, version: int = -1, return_nn: bool = False, return_sorted: bool = True)

        ### distance kernel 
        ell = self.ell_min + self.ell_rand
        length_scale_1 = ell * (pcl_gt_1.points_padded()[:,:, [2]]-self.ell_basedist) / self.ell_basedist    # points_padded is (N, P, D), length_scale is (N,P,1)
        length_scale_1 = length_scale_1.clamp(min=ell).pow(2)
        length_scale_2 = ell * (pcl_gt_2.points_padded()[:,:, [2]]-self.ell_basedist) / self.ell_basedist    # points_padded is (N, P, D), length_scale is (N,P,1)
        length_scale_2 = length_scale_2.clamp(min=ell).pow(2)

        dist_kernel_1 = torch.exp( - dists_1 / length_scale_1 )
        dist_kernel_flowed_1 = torch.exp( - dists_1_from_2 / length_scale_1 )
        dist_kernel_2 = torch.exp( - dists_2 / length_scale_2 )
        dist_kernel_flowed_2 = torch.exp( - dists_2_from_1 / length_scale_2 )

        ### color kernel
        color_scale = 0.2
        pcl_hsv_knn_to_1 = knn_gather(pcl_pred_1.features_padded()[:, :, :3], idxs_1, pcl_pred_1.num_points_per_cloud() )   # (N, P1, K, D)
        pcl_hsv_knn_to_2 = knn_gather(pcl_pred_2.features_padded()[:, :, :3], idxs_2, pcl_pred_2.num_points_per_cloud() )
        pcl_hsv_knn_flowed_to_1 = knn_gather(pcl_flowed_1_from_2.features_padded()[:, :, :3], idxs_1_from_2, pcl_flowed_1_from_2.num_points_per_cloud() )
        pcl_hsv_knn_flowed_to_2 = knn_gather(pcl_flowed_2_from_1.features_padded()[:, :, :3], idxs_2_from_1, pcl_flowed_2_from_1.num_points_per_cloud() )

        color_dist_1 = (pcl_gt_1.features_padded()[:, :, :3].unsqueeze(2) - pcl_hsv_knn_to_1).norm(dim=-1)  # (N, P1, K)
        color_dist_2 = (pcl_gt_2.features_padded()[:, :, :3].unsqueeze(2) - pcl_hsv_knn_to_2).norm(dim=-1)  # (N, P2, K)
        color_dist_flowed_1 = (pcl_gt_1.features_padded()[:, :, :3].unsqueeze(2) - pcl_hsv_knn_flowed_to_1).norm(dim=-1)  # (N, P1, K)
        color_dist_flowed_2 = (pcl_gt_2.features_padded()[:, :, :3].unsqueeze(2) - pcl_hsv_knn_flowed_to_2).norm(dim=-1)  # (N, P2, K)

        color_kernel_1 = torch.exp( - color_dist_1 / color_scale )
        color_kernel_2 = torch.exp( - color_dist_2 / color_scale )
        color_kernel_flowed_1 = torch.exp( - color_dist_flowed_1 / color_scale )
        color_kernel_flowed_2 = torch.exp( - color_dist_flowed_2 / color_scale )

        ### normal kernel

        # ### calculate normal
        # normal_gt_1 = estimate_pointcloud_normals(pcl_gt_1) # (N, P, 3)
        # normal_gt_2 = estimate_pointcloud_normals(pcl_gt_2) 
        # normal_pred_1 = estimate_pointcloud_normals(pcl_pred_1) 
        # normal_pred_2 = estimate_pointcloud_normals(pcl_pred_2) 
        # if flow_dict_1to2 is not None and flow_dict_2to1 is not None:
        #     normal_flowed_2_from_1 = estimate_pointcloud_normals(pcl_flowed_2_from_1) 
        #     normal_flowed_1_from_2 = estimate_pointcloud_normals(pcl_flowed_1_from_2)
        
        ### calculate normal
        normal_gt_1 = pcl_gt_1.normals_padded()
        normal_gt_2 = pcl_gt_2.normals_padded()
        normal_pred_1 = pcl_pred_1.normals_padded()
        normal_pred_2 = pcl_pred_2.normals_padded()
        if flow_dict_1to2 is not None and flow_dict_2to1 is not None:
            normal_flowed_2_from_1 = pcl_flowed_2_from_1.normals_padded()
            normal_flowed_1_from_2 = pcl_flowed_1_from_2.normals_padded()

        ### nres
        nres_gt_1 = pcl_gt_1.features_padded()[:,:,[3]]
        nres_gt_2 = pcl_gt_2.features_padded()[:,:,[3]]
        nres_pred_1 = pcl_pred_1.features_padded()[:,:,[3]]
        nres_pred_2 = pcl_pred_2.features_padded()[:,:,[3]]
        if flow_dict_1to2 is not None and flow_dict_2to1 is not None:
            nres_flowed_2_from_1 = pcl_flowed_2_from_1.features_padded()[:,:,[3]]
            nres_flowed_1_from_2 = pcl_flowed_1_from_2.features_padded()[:,:,[3]]

        # float res = pts_nres[0][0][in] + grid_nres[ib][0][v+innh][u+innw];
        #   float alpha = 2 * mag_min / (2*mag_min/mag_max + res);

        normal_knn_to_1 = knn_gather(normal_pred_1, idxs_1, pcl_pred_1.num_points_per_cloud() )   # (N, P1, K, D)
        normal_knn_to_2 = knn_gather(normal_pred_2, idxs_2, pcl_pred_2.num_points_per_cloud() )   # (N, P1, K, D)
        normal_knn_flowed_to_1 = knn_gather(normal_flowed_1_from_2, idxs_1_from_2, pcl_flowed_1_from_2.num_points_per_cloud() )   # (N, P1, K, D)
        normal_knn_flowed_to_2 = knn_gather(normal_flowed_2_from_1, idxs_2_from_1, pcl_flowed_2_from_1.num_points_per_cloud() )   # (N, P1, K, D)

        nres_knn_to_1 = knn_gather(nres_pred_1, idxs_1, pcl_pred_1.num_points_per_cloud() )   # (N, P1, K, D)
        nres_knn_to_2 = knn_gather(nres_pred_2, idxs_2, pcl_pred_2.num_points_per_cloud() )   # (N, P1, K, D)
        nres_knn_flowed_to_1 = knn_gather(nres_flowed_1_from_2, idxs_1_from_2, pcl_flowed_1_from_2.num_points_per_cloud() )   # (N, P1, K, D)
        nres_knn_flowed_to_2 = knn_gather(nres_flowed_2_from_1, idxs_2_from_1, pcl_flowed_2_from_1.num_points_per_cloud() )   # (N, P1, K, D)

        alpha_1 = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_gt_1.unsqueeze(2) + nres_knn_to_1 ).squeeze(-1)
        alpha_2 = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_gt_2.unsqueeze(2) + nres_knn_to_2 ).squeeze(-1)
        alpha_flowed_1 = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_gt_1.unsqueeze(2) + nres_knn_flowed_to_1 ).squeeze(-1)
        alpha_flowed_2 = 2 * self.res_mag_min / (2*self.res_mag_min/self.res_mag_max + nres_gt_2.unsqueeze(2) + nres_knn_flowed_to_2 ).squeeze(-1)

        normal_kernel_1 = ((normal_gt_1.unsqueeze(2) * normal_knn_to_1).sum(-1)*alpha_1).clamp(min=0)
        normal_kernel_2 = ((normal_gt_2.unsqueeze(2) * normal_knn_to_2).sum(-1)*alpha_2).clamp(min=0)
        normal_kernel_flowed_1 = ((normal_gt_1.unsqueeze(2) * normal_knn_flowed_to_1).sum(-1)*alpha_flowed_1).clamp(min=0)
        normal_kernel_flowed_2 = ((normal_gt_2.unsqueeze(2) * normal_knn_flowed_to_2).sum(-1)*alpha_flowed_2).clamp(min=0)

        ### final kernel
        mask_1 = mask_from_pcl(pcl_gt_1)
        mask_2 = mask_from_pcl(pcl_gt_2)

        kernel_1 = (dist_kernel_1 * color_kernel_1 * normal_kernel_1 * mask_1).sum() / (pcl_gt_1.num_points_per_cloud().sum()*self.num_nn)
        kernel_2 = (dist_kernel_2 * color_kernel_2 * normal_kernel_2 * mask_2).sum() / (pcl_gt_2.num_points_per_cloud().sum()*self.num_nn)

        kernel_flowed_1 = (dist_kernel_flowed_1 * color_kernel_flowed_1 * normal_kernel_flowed_1 * mask_1).sum() / (pcl_gt_1.num_points_per_cloud().sum()*self.num_nn)
        kernel_flowed_2 = (dist_kernel_flowed_2 * color_kernel_flowed_2 * normal_kernel_flowed_2 * mask_2).sum() / (pcl_gt_2.num_points_per_cloud().sum()*self.num_nn)

        if self.log_loss:
            inp = kernel_1.log() + kernel_2.log() + kernel_flowed_1.log() + kernel_flowed_2.log()
        else:
            inp = kernel_1 + kernel_2 + kernel_flowed_1 + kernel_flowed_2
        
        return inp
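The alpha factor converts the summed normal residuals into a confidence in (0, res_mag_max]: zero residual yields res_mag_max and large residuals drive it toward zero, mirroring the C snippet quoted in the comments above. A toy evaluation with assumed magnitudes:

import torch

mag_min, mag_max = 0.01, 2.0         # assumed res_mag_min / res_mag_max
res = torch.tensor([0.0, 0.1, 1.0])  # summed normal residuals
alpha = 2 * mag_min / (2 * mag_min / mag_max + res)
print(alpha)  # tensor([2.0000, 0.1818, 0.0198]); res == 0 gives mag_max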
Example No. 10
    def compute(self,
                point_clouds: PointClouds3D,
                points_filters=None,
                rebuild_knn=True,
                **kwargs):

        self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
        self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

        lengths = point_clouds.num_points_per_cloud()
        P_total = lengths.sum().item()
        points_padded = point_clouds.points_padded()

        # Compute necessary weights to project points to local plane
        # TODO(yifan): This part is same as ProjectionLoss
        # how can we at best save repetitive computation
        with torch.autograd.no_grad():
            if (rebuild_knn or self.knn_tree is None
                    or points_padded.shape[:2] != self.knn_tree.idx.shape[:2]):
                self._build_knn(point_clouds)

            phi = self.get_phi(point_clouds, **kwargs)

            self._denoise_normals(point_clouds, phi, points_filters)

            # compute wn and wr
            # TODO(yifan): visibility weight?
            normal_w = self.get_normal_w(point_clouds, **kwargs)

            # update normals for a second iteration (?) Eq.(10)
            point_clouds = self._denoise_normals(point_clouds, phi * normal_w,
                                                 points_filters)

            # compose weights
            weights = phi * normal_w
            weights[~self.knn_mask] = 0

            # outside filter_scale*local_point_spacing weights
            mask_ball_query = self.knn_tree.dists > (
                self.filter_scale * self.knn_tree.dists[:, :, :1] * 2.0)
            weights[mask_ball_query] = 0.0

        # project the point to a local surface
        knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                     self.knn_tree.idx, lengths)
        dist_to_surface = torch.sum(
            (self.knn_tree.knn.detach() - points_padded.unsqueeze(-2)) *
            knn_normals,
            dim=-1)
        deltap = torch.sum(
            dist_to_surface[..., None] * weights[..., None] * knn_normals,
            dim=-2) / eps_denom(torch.sum(weights, dim=-1, keepdim=True))
        points_projected = points_padded + deltap

        if get_debugging_mode():
            # points_padded.requires_grad_(True)

            def save_grad():
                lengths = point_clouds.num_points_per_cloud()

                def _save_grad(grad):
                    dbg_tensor = get_debugging_tensor()
                    if dbg_tensor is None:
                        logger_py.error("dbg_tensor is None")
                    if grad is None:
                        logger_py.error('grad is None')
                    # a dict of list of tensors
                    dbg_tensor.pts_world_grad['repel'] = [
                        grad[b, :lengths[b]].detach().cpu()
                        for b in range(grad.shape[0])
                    ]

                return _save_grad

            dbg_tensor = get_debugging_tensor()
            dbg_tensor.pts_world['repel'] = [
                points_padded[b, :lengths[b]].detach().cpu()
                for b in range(points_padded.shape[0])
            ]
            handle = points_padded.register_hook(save_grad())
            self.hooks.append(handle)

        with torch.autograd.no_grad():
            spatial_w = self.get_spatial_w(point_clouds, points_projected)
            # density_w = self.get_density_w(point_clouds)
            # density weight is actually spatial_w + 1
            density_w = torch.sum(spatial_w, dim=-1, keepdim=True) + 1.0
            weights = normal_w * spatial_w * density_w
            weights[~self.knn_mask] = 0
            weights[mask_ball_query] = 0

        deltap = points_projected[:, :, None, :] - self.knn_tree.knn.detach()
        point_to_point_dist = torch.sum(deltap * deltap, dim=-1)

        # convert everything to packed
        weights = ops.padded_to_packed(
            weights, point_clouds.cloud_to_packed_first_idx(), P_total)
        point_to_point_dist = ops.padded_to_packed(
            point_to_point_dist, point_clouds.cloud_to_packed_first_idx(),
            P_total)

        # we want to maximize this, so negative sign
        point_to_point_dist = -torch.sum(point_to_point_dist * weights,
                                         dim=1) / eps_denom(
                                             torch.sum(weights, dim=1))
        return point_to_point_dist
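The projection step computes delta_p = sum_i w_i * dot(n_i, x_i - x) * n_i / sum_i w_i, i.e. it moves each point onto a weighted average of its neighbors' tangent planes. A self-contained sketch with fake neighbors:

import torch
import torch.nn.functional as F

N, P, K = 1, 64, 8
x = torch.rand(N, P, 3)
knn_pts = x[:, torch.randint(P, (P, K)), :]             # fake neighbors, (N, P, K, 3)
n_i = F.normalize(torch.randn(N, P, K, 3), dim=-1)      # neighbor normals
w = torch.rand(N, P, K)                                 # composed weights
d = torch.sum((knn_pts - x[:, :, None]) * n_i, dim=-1)  # (N, P, K) plane distances
deltap = torch.sum(d[..., None] * w[..., None] * n_i, dim=-2) / \
    torch.sum(w, dim=-1, keepdim=True).clamp(min=1e-10)
x_proj = x + deltap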
Example No. 11
    def compute(self,
                point_clouds: PointClouds3D,
                points_filters=None,
                rebuild_knn=False,
                **kwargs):
        """
        Args:
            point_clouds
            (optional) knn_tree: output from ops.knn_points excluding the query point itself
            (optional) knn_mask: mask valid knn results
        Returns:
            (P, N)
        """
        self.sharpness_sigma = kwargs.get('sharpness_sigma',
                                          self.sharpness_sigma)
        self.filter_scale = kwargs.get('filter_scale', self.filter_scale)
        self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
        self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

        lengths = point_clouds.num_points_per_cloud()
        P_total = lengths.sum().item()
        points = point_clouds.points_padded()
        # - determine phi spatial with using local point spacing (i.e. 2*dist_to_nn)
        # - denoise normals
        # - determine w_normal
        # - mask out values outside ballneighbor i.e. d > filterSpatialScale * localPointSpacing
        # - projected distance dot(ni, x-xi)
        # - multiply and normalize the weights
        with torch.autograd.no_grad():
            if (rebuild_knn or self.knn_tree is None
                    or self.knn_tree.idx.shape[:2] != points.shape[:2]):
                self._build_knn(point_clouds)

            phi = self.get_phi(point_clouds, **kwargs)

            # robust normal mollification (Sec 4.4), i.e. replace normals with a weighted average
            # from neighboring normals Eq.(11)
            point_clouds = self._denoise_normals(point_clouds, phi,
                                                 points_filters)

            # compute wn and wr
            # TODO(yifan): visibility weight?
            normal_w = self.get_normal_w(point_clouds, **kwargs)
            spatial_w = self.get_spatial_w(point_clouds, **kwargs)

            # update normals for a second iteration (?) Eq.(10)
            point_clouds = self._denoise_normals(point_clouds, phi * normal_w,
                                                 points_filters)

            # compose weights
            weights = phi * spatial_w * normal_w
            weights[~self.knn_mask] = 0

            # outside filter_scale*local_point_spacing weights
            mask_ball_query = self.knn_tree.dists > (
                self.filter_scale * self.knn_tree.dists[:, :, :1] * 2.0)
            weights[mask_ball_query] = 0.0

            # (B, P, k), dot product distance to surface
            # (we need to gather again because the normals have been changed in the denoising step)
            knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                         self.knn_tree.idx, lengths)

        # if points.requires_grad:
        #     from DSS.core.rasterizer import _dbg_tensor

        #     def save_grad(name):
        #         def _save_grad(grad):
        #             _dbg_tensor[name] = grad.detach().cpu()
        #         return _save_grad
        #     points.register_hook(save_grad('proj_grad'))

        dist_to_surface = torch.sum(
            (self.knn_tree.knn.detach() - points.unsqueeze(-2)) * knn_normals,
            dim=-1)

        if get_debugging_mode():
            # points.requires_grad_(True)

            def save_grad():
                lengths = point_clouds.num_points_per_cloud()

                def _save_grad(grad):
                    dbg_tensor = get_debugging_tensor()
                    if dbg_tensor is None:
                        logger_py.error("dbg_tensor is None")
                    if grad is None:
                        logger_py.error('grad is None')
                    # a dict of list of tensors
                    dbg_tensor.pts_world_grad['proj'] = [
                        grad[b, :lengths[b]].detach().cpu()
                        for b in range(grad.shape[0])
                    ]

                return _save_grad

            dbg_tensor = get_debugging_tensor()
            dbg_tensor.pts_world['proj'] = [
                points[b, :lengths[b]].detach().cpu()
                for b in range(points.shape[0])
            ]
            handle = points.register_hook(save_grad())
            self.hooks.append(handle)

        # convert everything to packed
        weights = ops.padded_to_packed(
            weights, point_clouds.cloud_to_packed_first_idx(), P_total)
        dist_to_surface = ops.padded_to_packed(
            dist_to_surface, point_clouds.cloud_to_packed_first_idx(), P_total)

        # compute weighted signed distance to surface
        dist_to_surface = torch.sum(weights * dist_to_surface,
                                    dim=-1) / eps_denom(
                                        torch.sum(weights, dim=-1))
        loss = dist_to_surface * dist_to_surface
        return loss
Example No. 12
def estimate_normal(pc, k):
    with torch.no_grad():
        # pc : [b, 3, n]
        b, _, n = pc.size()
        # get knn point set matrix
        inter_KNN = knn_points(pc.permute(0, 2, 1),
                               pc.permute(0, 2, 1),
                               K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]
        nn_pts = knn_gather(pc.permute(0, 2, 1), inter_KNN.idx).permute(
            0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n, k]

        # get covariance matrix and smallest eig-vector of individual point
        normal_vector = []
        for i in range(b):
            # the covariance/eig path below assumes torch >= 0.4
            if tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (0, 4):
                curr_point_set = nn_pts[i].detach().permute(
                    1, 0, 2)  #curr_point_set:[n, 3, k]
                curr_point_set_mean = torch.mean(
                    curr_point_set, dim=2,
                    keepdim=True)  #curr_point_set_mean:[n, 3, 1]
                curr_point_set = curr_point_set - curr_point_set_mean  #curr_point_set:[n, 3, k]
                curr_point_set_t = curr_point_set.permute(
                    0, 2, 1)  #curr_point_set_t:[n, k, 3]
                fact = 1.0 / (k - 1)
                cov_mat = fact * torch.bmm(
                    curr_point_set,
                    curr_point_set_t)  #cov_mat:[n, 3, 3]
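                # NOTE: torch.symeig below is deprecated and removed in recent
                # PyTorch; torch.linalg.eigh(cov_mat) is the modern equivalent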
                eigenvalue, eigenvector = torch.symeig(
                    cov_mat, eigenvectors=True
                )  # eigenvalue:[n, 3], eigenvector:[n, 3, 3]
                persample_normal_vector = torch.gather(
                    eigenvector, 2,
                    torch.argmin(
                        eigenvalue, dim=1).unsqueeze(1).unsqueeze(2).expand(
                            n, 3,
                            1)).squeeze()  #persample_normal_vector:[n, 3]

                #recorrect the direction via neighbour direction
                nbr_sum = curr_point_set.sum(dim=2)  #curr_point_set:[n, 3]
                sign = -torch.sign(
                    torch.bmm(persample_normal_vector.view(n, 1, 3),
                              nbr_sum.view(n, 3, 1))).squeeze(2)
                persample_normal_vector = sign * persample_normal_vector

                normal_vector.append(persample_normal_vector.permute(1, 0))

            else:
                persample_normal_vector = []
                for j in range(n):
                    curr_point_set = nn_pts[i, :, j, :].cpu()
                    curr_point_set_np = curr_point_set.detach().numpy()  #curr_point_set_np:[3,k]
                    cov_mat_np = np.cov(curr_point_set_np)  #cov_mat:[3,3]
                    # note that v[:,i] is the eigenvector corresponding to the eigenvalue w[i]
                    eigenvalue_np, eigenvector_np = np.linalg.eig(cov_mat_np)
                    curr_normal_vector = torch.from_numpy(
                        np.real(eigenvector_np[:, np.argmin(eigenvalue_np)]).
                        astype(np.float32))  #curr_normal_vector:[3]

                    #recorrect the direction via neighbour direction
                    nbr_sum = (curr_point_set - curr_point_set.mean(
                        dim=1, keepdim=True)).sum(dim=1)  #nbr_sum:[3]
                    sign = -torch.sign(torch.dot(curr_normal_vector, nbr_sum))
                    persample_normal_vector.append(sign * curr_normal_vector)

                persample_normal_vector = torch.stack(persample_normal_vector,
                                                      1)  #[3, n]
                normal_vector.append(persample_normal_vector)

        normal_vector = torch.stack(normal_vector, 0)  #normal_vector:[b, 3, n]
    return normal_vector.float()
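For reference, a batched sketch of the same PCA idea using torch.linalg.eigh, which replaces the deprecated torch.symeig in recent PyTorch (same [b, 3, n] layout as above; the normal sign is left undisambiguated):

import torch
from pytorch3d.ops import knn_points, knn_gather

def estimate_normal_eigh(pc, k=10):
    # pc: [b, 3, n] -> normals: [b, 3, n]
    knn = knn_points(pc.permute(0, 2, 1), pc.permute(0, 2, 1), K=k + 1)
    nbrs = knn_gather(pc.permute(0, 2, 1), knn.idx)[:, :, 1:]  # [b, n, k, 3]
    centered = nbrs - nbrs.mean(dim=2, keepdim=True)
    cov = centered.transpose(-1, -2) @ centered / (k - 1)      # [b, n, 3, 3]
    eigval, eigvec = torch.linalg.eigh(cov)  # eigenvalues in ascending order
    return eigvec[..., 0].permute(0, 2, 1)   # smallest-eigenvalue eigenvector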
Example No. 13
    def compute(self,
                point_clouds: PointClouds3D,
                points_filter=None,
                rebuild_knn=False,
                **kwargs):
        """
        Args:
            point_clouds
            (optional) knn_tree: output from ops.knn_points excluding the query point itself
            (optional) knn_mask: mask valid knn results
        Returns:
            (P, N)
        """
        self.sharpness_sigma = kwargs.get('sharpness_sigma',
                                          self.sharpness_sigma)
        self.filter_scale = kwargs.get('filter_scale', self.filter_scale)
        self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
        self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

        lengths = point_clouds.num_points_per_cloud()
        P_total = lengths.sum().item()
        points = point_clouds.points_padded()
        # - determine phi spatial with using local point spacing (i.e. 2*dist_to_nn)
        # - denoise normals
        # - determine w_normal
        # - mask out values outside ballneighbor i.e. d > filterSpatialScale * localPointSpacing
        # - projected distance dot(ni, x-xi)
        # - multiply and normalize the weights
        with torch.autograd.no_grad():
            if (rebuild_knn or self.knn_tree is None
                    or self.knn_tree.idx.shape[:2] != points.shape[:2]):
                self._build_knn(point_clouds)

            phi = self.get_phi(point_clouds, **kwargs)

            # robust normal mollification (Sec 4.4), i.e. replace normals with a weighted average
            # from neighboring normals Eq.(11)
            point_clouds = self._denoise_normals(point_clouds,
                                                 phi,
                                                 points_filter,
                                                 inplace=False)

            # compute wn and wr
            normal_w = self.get_normal_w(point_clouds, **kwargs)
            # visibility weight
            visibility_nb = ops.knn_gather(
                points_filter.visibility.unsqueeze(-1), self.knn_tree.idx,
                lengths)
            visibility_w = visibility_nb.float()
            visibility_w[~visibility_nb] = 0.1
            # compose weights
            weights = phi * normal_w * visibility_w.squeeze(-1)

            # (B, P, k), dot product distance to surface
            knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                         self.knn_tree.idx, lengths)

        if get_debugging_mode():
            # points.requires_grad_(True)

            def save_grad():
                lengths = point_clouds.num_points_per_cloud()

                def _save_grad(grad):
                    dbg_tensor = get_debugging_tensor()
                    if dbg_tensor is None:
                        logger_py.error("dbg_tensor is None")
                    if grad is None:
                        logger_py.error('grad is None')
                    # a dict of list of tensors
                    dbg_tensor.pts_world_grad['proj'] = [
                        grad[b, :lengths[b]].detach().cpu()
                        for b in range(grad.shape[0])
                    ]

                return _save_grad

            if points.requires_grad:
                dbg_tensor = get_debugging_tensor()
                dbg_tensor.pts_world['proj'] = [
                    points[b, :lengths[b]].detach().cpu()
                    for b in range(points.shape[0])
                ]
                handle = points.register_hook(save_grad())
                self.hooks.append(handle)

        sdf = torch.sum(
            (self.knn_tree.knn.detach() - points.unsqueeze(-2)) * knn_normals,
            dim=-1)

        # convert everything to packed
        weights = ops.padded_to_packed(
            weights, point_clouds.cloud_to_packed_first_idx(), P_total)
        sdf = ops.padded_to_packed(sdf,
                                   point_clouds.cloud_to_packed_first_idx(),
                                   P_total)

        # if get_debugging_mode():
        #     # save to dbg folder as normal
        #     from ..utils.io import save_ply
        #     save_ply('./dbg_repel_diff.ply', point_clouds.points_packed().cpu().detach(), normals=repel_vec.cpu().detach())

        distance_to_face = sdf * sdf
        # compute weighted signed distance to surface
        loss = torch.sum(weights * distance_to_face, dim=-1) / eps_denom(
            torch.sum(weights, dim=-1))

        return loss
Example No. 14
def run(pointcloud_path, out_dir, decoder_type='siren', resume=True, **kwargs):
    """
    test_implicit_siren_noisy_wNormals
    """
    device = torch.device('cuda:0')

    if not os.path.exists(out_dir):
        os.makedirs(out_dir)

    # data
    points, normals = np.split(read_ply(pointcloud_path).astype('float32'),
                               (3, ),
                               axis=1)

    pmax, pmin = points.max(axis=0), points.min(axis=0)
    scale = (pmax - pmin).max()
    pcenter = (pmax + pmin) / 2
    points = (points - pcenter) / scale * 1.5
    scale_mat = np.identity(4)
    scale_mat[[0, 1, 2], [0, 1, 2]] = 1 / scale * 1.5
    scale_mat[[0, 1, 2], [3, 3, 3]] = -pcenter / scale * 1.5
    scale_mat_inv = np.linalg.inv(scale_mat)
    normals = normals @ np.linalg.inv(scale_mat[:3, :3].T)
    object_bounding_sphere = np.linalg.norm(points, axis=1).max()
    pcl = trimesh.Trimesh(vertices=points,
                          vertex_normals=normals,
                          process=False)
    pcl.export(os.path.join(out_dir, "input_pcl.ply"), vertex_normal=True)

    assert (np.abs(points).max() < 1)

    dataset = torch.utils.data.TensorDataset(torch.from_numpy(points),
                                             torch.from_numpy(normals))
    batch_size = 5000
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        num_workers=1,
        shuffle=True,
        collate_fn=tolerating_collate,
    )
    gt_surface_pts_all = torch.from_numpy(points).unsqueeze(0).float()
    gt_surface_normals_all = torch.from_numpy(normals).unsqueeze(0).float()
    gt_surface_normals_all = F.normalize(gt_surface_normals_all, dim=-1)

    if kwargs['use_off_normal_loss']:
        # subsample from pointset
        sub_idx = torch.randperm(gt_surface_normals_all.shape[1])[:20000]
        gt_surface_pts_sub = torch.index_select(gt_surface_pts_all, 1,
                                                sub_idx).to(device=device)
        gt_surface_normals_sub = torch.index_select(gt_surface_normals_all, 1,
                                                    sub_idx).to(device=device)
        gt_surface_normals_sub = denoise_normals(gt_surface_pts_sub,
                                                 gt_surface_normals_sub,
                                                 neighborhood_size=30)

    if decoder_type == 'siren':
        decoder_params = {
            'dim': 3,
            "out_dims": {
                'sdf': 1
            },
            "c_dim": 0,
            "hidden_size": 256,
            'n_layers': 3,
            "first_omega_0": 30,
            "hidden_omega_0": 30,
            "outermost_linear": True,
        }
        decoder = Siren(**decoder_params)
        # pretrained_model_file = os.path.join('data', 'trained_model', 'siren_l{}_c{}_o{}.pt'.format(
        #                     decoder_params['n_layers'], decoder_params['hidden_size'], decoder_params['first_omega_0']))
        # loaded_state_dict = torch.load(pretrained_model_file)
        # decoder.load_state_dict(loaded_state_dict)
    elif decoder_type == 'sdf':
        decoder_params = {
            'dim': 3,
            "out_dims": {
                'sdf': 1
            },
            "c_dim": 0,
            "hidden_size": 512,
            'n_layers': 8,
            'bias': 1.0,
        }
        decoder = SDF(**decoder_params)
    else:
        raise ValueError
    print(decoder)
    decoder = decoder.to(device)

    # training
    total_iter = 30000
    optimizer = torch.optim.Adam(decoder.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10000, 20000],
                                                     gamma=0.5)

    shape = Shape(gt_surface_pts_all.cuda(),
                  n_points=gt_surface_pts_all.shape[1] // 16,
                  normals=gt_surface_normals_all.cuda())
    # initialize siren with sphere_initialization
    checkpoint_io = CheckpointIO(out_dir, model=decoder, optimizer=optimizer)
    load_dict = dict()
    if resume:
        models_avail = [f for f in os.listdir(out_dir) if f[-3:] == '.pt']
        if len(models_avail) > 0:
            models_avail.sort()
            load_dict = checkpoint_io.load(models_avail[-1])

    it = load_dict.get('it', 0)
    if it > 0:
        try:
            iso_point_files = [
                f for f in os.listdir(out_dir) if f[-7:] == 'iso.ply'
            ]
            iso_point_iters = [
                int(os.path.basename(f[:-len('_iso.ply')]))
                for f in iso_point_files
            ]
            iso_point_iters = np.array(iso_point_iters)
            idx = np.argmax(iso_point_iters[(iso_point_iters - it) <= 0])
            iso_point_file = np.array(iso_point_files)[(iso_point_iters -
                                                        it) <= 0][idx]
            iso_points = torch.from_numpy(
                read_ply(os.path.join(out_dir, iso_point_file))[..., :3])
            shape.points = iso_points.to(device=shape.points.device).view(
                1, -1, 3)
            print('Loaded iso-points from %s' % iso_point_file)
        except Exception as e:
            pass

    # loss
    eikonal_loss = NormalLengthLoss(reduction='mean')

    # start training
    # save_ply(os.path.join(out_dir, 'in_iso_points.ply'), (to_homogen(shape.points).cpu().detach().numpy() @ scale_mat_inv.T)[...,:3].reshape(-1,3))
    save_ply(os.path.join(out_dir, 'in_iso_points.ply'),
             shape.points.cpu().view(-1, 3))
    # autograd.set_detect_anomaly(True)
    iso_points = shape.points
    iso_points_normal = None
    while True:
        if (it > total_iter):
            checkpoint_io.save('model_{:04d}.pt'.format(it), it=it)
            mesh = get_surface_high_res_mesh(
                lambda x: decoder(x).sdf.squeeze(), resolution=512)
            mesh.apply_transform(scale_mat_inv)
            mesh.export(os.path.join(out_dir, "final.ply"))
            break

        for batch in data_loader:

            gt_surface_pts, gt_surface_normals = batch
            gt_surface_pts.unsqueeze_(0)
            gt_surface_normals.unsqueeze_(0)
            gt_surface_pts = gt_surface_pts.to(device=device).detach()
            gt_surface_normals = gt_surface_normals.to(device=device).detach()

            optimizer.zero_grad()
            decoder.train()
            loss = defaultdict(float)

            lambda_surface_sdf = 1e3
            lambda_surface_normal = 1e2
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up']:
                lambda_surface_sdf = kwargs['lambda_surface_sdf']
                lambda_surface_normal = kwargs['lambda_surface_normal']

            # debug
            if (it - kwargs['warm_up']) % 1000 == 0:
                # generate iso surface
                with torch.autograd.no_grad():
                    box_size = (object_bounding_sphere * 2 + 0.2, ) * 3
                    imgs = plot_cuts(
                        lambda x: decoder(x).sdf.squeeze().detach(),
                        box_size=box_size,
                        max_n_eval_pts=10000,
                        thres=0.0,
                        imgs_per_cut=1,
                        save_path=os.path.join(out_dir, '%010d_iso.html' % it))
                    mesh = get_surface_high_res_mesh(
                        lambda x: decoder(x).sdf.squeeze(), resolution=200)
                    mesh.apply_transform(scale_mat_inv)
                    mesh.export(os.path.join(out_dir, '%010d_mesh.ply' % it))

            if it % 2000 == 0:
                checkpoint_io.save('model.pt', it=it)

            pred_surface_grad = gradient(gt_surface_pts.clone(),
                                         lambda x: decoder(x).sdf)

            # every once in a while update shape and points
            # sample points in space and on the shape
            # use iso points to weigh data points loss
            weights = 1.0
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up']:
                if it == kwargs['warm_up'] or (
                        kwargs['resample_every'] > 0 and
                        (it - kwargs['warm_up']) % kwargs['resample_every'] == 0):
                    # if shape.points.shape[1]/iso_points.shape[1] < 1.0:
                    #     idx = fps(iso_points.view(-1,3), torch.zeros(iso_points.shape[1], dtype=torch.long, device=iso_points.device), shape.points.shape[1]/iso_points.shape[1])
                    #     iso_points = iso_points.view(-1,3)[idx].view(1,-1,3)

                    iso_points = shape.get_iso_points(
                        iso_points + 0.1 * (torch.rand_like(iso_points) - 0.5),
                        decoder,
                        ear=kwargs['ear'],
                        outlier_tolerance=kwargs['outlier_tolerance'])
                    # iso_points = shape.get_iso_points(shape.points, decoder, ear=kwargs['ear'], outlier_tolerance=kwargs['outlier_tolerance'])
                    iso_points_normal = estimate_pointcloud_normals(
                        iso_points.view(1, -1, 3), 8, False)
                    if kwargs['denoise_normal']:
                        iso_points_normal = denoise_normals(iso_points,
                                                            iso_points_normal,
                                                            num_points=None)
                        iso_points_normal = iso_points_normal.view_as(
                            iso_points)
                elif iso_points_normal is None:
                    iso_points_normal = estimate_pointcloud_normals(
                        iso_points.view(1, -1, 3), 8, False)

                # iso_points = resample_uniformly(iso_points.view(1,-1,3))
                # TODO: use gradient from network or neighborhood?
                iso_points_g = gradient(iso_points.clone(),
                                        lambda x: decoder(x).sdf)
                if it == kwargs['warm_up'] or (
                        kwargs['resample_every'] > 0 and
                        (it - kwargs['warm_up']) % kwargs['resample_every'] == 0):
                    # save_ply(os.path.join(out_dir, '%010d_iso.ply' % it), (to_homogen(iso_points).cpu().detach().numpy() @ scale_mat_inv.T)[...,:3].reshape(-1,3), normals=iso_points_g.view(-1,3).detach().cpu())
                    save_ply(os.path.join(out_dir, '%010d_iso.ply' % it),
                             iso_points.cpu().detach().view(-1, 3),
                             normals=iso_points_g.view(-1, 3).detach().cpu())

                if kwargs['weight_mode'] == 1:
                    weights = get_iso_bilateral_weights(
                        gt_surface_pts, gt_surface_normals, iso_points,
                        iso_points_g).detach()
                elif kwargs['weight_mode'] == 2:
                    weights = get_laplacian_weights(gt_surface_pts,
                                                    gt_surface_normals,
                                                    iso_points,
                                                    iso_points_g).detach()
                elif kwargs['weight_mode'] == 3:
                    weights = get_heat_kernel_weights(gt_surface_pts,
                                                      gt_surface_normals,
                                                      iso_points,
                                                      iso_points_g).detach()

                if (it - kwargs['warm_up']) % 1000 == 0 and kwargs['weight_mode'] != -1:
                    print("min {:.4g}, max {:.4g}, std {:.4g}, mean {:.4g}".
                          format(weights.min(), weights.max(), weights.std(),
                                 weights.mean()))
                    colors = scaler_to_color(1 -
                                             weights.view(-1).cpu().numpy(),
                                             cmap='Reds')
                    save_ply(
                        os.path.join(out_dir, '%010d_batch_weight.ply' % it),
                        (to_homogen(gt_surface_pts).cpu().detach().numpy()
                         @ scale_mat_inv.T)[..., :3].reshape(-1, 3),
                        colors=colors)

                sample_idx = torch.randperm(
                    iso_points.shape[1])[:min(batch_size, iso_points.shape[1])]
                iso_points_sampled = iso_points.detach()[:, sample_idx, :]
                # iso_points_sampled = iso_points.detach()
                iso_points_sdf = decoder(iso_points_sampled.detach()).sdf
                loss_iso_points_sdf = iso_points_sdf.abs().mean() * \
                    kwargs['lambda_iso_sdf'] * iso_points_sdf.nelement() / (
                        iso_points_sdf.nelement() + 8000)
                loss['loss_sdf_iso'] = loss_iso_points_sdf.detach()
                loss['loss'] += loss_iso_points_sdf

                # TODO: predict iso_normals from local_frame
                iso_normals_sampled = iso_points_normal.detach()[:, sample_idx, :]
                iso_g_sampled = iso_points_g[:, sample_idx, :]
                loss_normals = torch.mean(
                    (1 - F.cosine_similarity(
                        iso_normals_sampled, iso_g_sampled, dim=-1).abs())
                ) * kwargs['lambda_iso_normal'] * iso_points_sdf.nelement() / (
                    iso_points_sdf.nelement() + 8000)
                # loss_normals = torch.mean((1 - F.cosine_similarity(iso_points_normal, iso_points_g, dim=-1).abs())) * kwargs['lambda_iso_normal']
                loss['loss_normal_iso'] = loss_normals.detach()
                loss['loss'] += loss_normals

            idx = torch.randperm(gt_surface_pts.shape[1]).to(
                device=gt_surface_pts.device)[:(gt_surface_pts.shape[1] // 2)]
            tmp = torch.index_select(gt_surface_pts, 1, idx)
            space_pts = torch.cat([
                torch.rand_like(tmp) * 2 - 1,
                torch.randn_like(tmp, device=tmp.device, dtype=tmp.dtype) * 0.1
                + tmp
            ],
                                  dim=1)

            space_pts.requires_grad_(True)
            pred_space_sdf = decoder(space_pts).sdf
            pred_space_grad = torch.autograd.grad(
                pred_space_sdf, [space_pts], [torch.ones_like(pred_space_sdf)],
                create_graph=True)[0]

            # 1. eikonal term
            loss_eikonal = (
                eikonal_loss(pred_surface_grad) +
                eikonal_loss(pred_space_grad)) * kwargs['lambda_eikonal']
            loss['loss_eikonal'] = loss_eikonal.detach()
            loss['loss'] += loss_eikonal

            # 2. SDF loss
            # loss on iso points
            pred_surface_sdf = decoder(gt_surface_pts).sdf

            loss_sdf = torch.mean(
                weights * pred_surface_sdf.abs()) * lambda_surface_sdf
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up'] and kwargs[
                    'lambda_iso_sdf'] != 0:
                # loss_sdf = 0.5 * loss_sdf
                loss_sdf = loss_sdf * pred_surface_sdf.nelement() / (
                    pred_surface_sdf.nelement() + iso_points_sdf.nelement())

            if kwargs['use_sal_loss'] and iso_points is not None:
                dists, idxs, _ = knn_points(space_pts.view(1, -1, 3),
                                            iso_points.view(1, -1, 3).detach(),
                                            K=1)
                dists = dists.view_as(pred_space_sdf)
                idxs = idxs.view_as(pred_space_sdf)
                loss_inter = ((eps_sqrt(dists).sqrt() - pred_space_sdf.abs())**
                              2).mean() * kwargs['lambda_inter_sal']
            else:
                alpha = (it / total_iter + 1) * 100
                loss_inter = torch.exp(
                    -alpha *
                    pred_space_sdf.abs()).mean() * kwargs['lambda_inter_sdf']

            loss_sald = torch.tensor(0.0).cuda()
            # prevent wrong closing for open mesh
            if kwargs['use_off_normal_loss'] and it < 1000:
                dists, idxs, _ = knn_points(space_pts.view(1, -1, 3),
                                            gt_surface_pts_sub.view(1, -1,
                                                                    3).cuda(),
                                            K=1)
                knn_normal = knn_gather(
                    gt_surface_normals_sub.cuda().view(1, -1, 3),
                    idxs).view(1, -1, 3)
                direction_correctness = -F.cosine_similarity(
                    knn_normal, pred_space_grad, dim=-1)
                direction_correctness[direction_correctness < 0] = 0
                loss_sald = torch.mean(
                    direction_correctness * torch.exp(-2 * dists)) * 2

            # 3. normal direction
            loss_normals = torch.mean(weights * (1 - F.cosine_similarity(
                gt_surface_normals, pred_surface_grad, dim=-1))
                                      ) * lambda_surface_normal
            if kwargs['warm_up'] >= 0 and it >= kwargs['warm_up'] and kwargs[
                    'lambda_iso_normal'] != 0:
                # loss_normals = 0.5 * loss_normals
                loss_normals = loss_normals * gt_surface_normals.nelement() / (
                    gt_surface_normals.nelement() +
                    iso_normals_sampled.nelement())

            loss['loss_sdf'] = loss_sdf.detach()
            loss['loss_inter'] = loss_inter.detach()
            loss['loss_normals'] = loss_normals.detach()
            loss['loss_sald'] = loss_sald.detach()
            loss['loss'] += loss_sdf
            loss['loss'] += loss_inter
            loss['loss'] += loss_sald
            loss['loss'] += loss_normals

            loss['loss'].backward()
            torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.)

            optimizer.step()
            scheduler.step()
            if it % 20 == 0:
                print("iter {:05d} {}".format(
                    it, ', '.join([
                        '{}: {}'.format(k, v.item()) for k, v in loss.items()
                    ])))

            it += 1
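
The loop above relies on two small helpers that are not shown in this example: eikonal_loss, which penalizes SDF gradients whose norm deviates from 1 (the implicit geometric regularization term of Gropp et al.), and eps_sqrt, a numerical guard applied before taking a square root. A minimal sketch of both, assuming exactly those semantics:

import torch


def eikonal_loss(grad: torch.Tensor) -> torch.Tensor:
    # grad: (..., 3) SDF gradients at sampled points; a signed distance
    # function has unit-norm gradients almost everywhere
    return ((grad.norm(p=2, dim=-1) - 1.0)**2).mean()


def eps_sqrt(squared: torch.Tensor, eps: float = 1e-17) -> torch.Tensor:
    # clamp the squared distances away from zero so .sqrt() keeps a finite
    # gradient at 0
    return torch.clamp(squared, min=eps)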
Example No. 15
    def compute(self,
                point_clouds: PointClouds3D,
                points_filter=None,
                rebuild_knn=True,
                **kwargs):

        self.knn_tree = kwargs.get('knn_tree', self.knn_tree)
        self.knn_mask = kwargs.get('knn_mask', self.knn_mask)

        lengths = point_clouds.num_points_per_cloud()
        P_total = lengths.sum().item()
        points_padded = point_clouds.points_padded()

        if not points_padded.requires_grad:
            logger_py.warning(
                'Computing repulsion loss, but points_padded is not differentiable.'
            )

        # Compute necessary weights to project points to local plane
        # TODO(yifan): This part is same as ProjectionLoss
        # how can we at best save repetitive computation
        with torch.autograd.no_grad():
            if rebuild_knn or self.knn_tree is None or \
                    points_padded.shape[:2] != self.knn_tree.shape[:2]:
                self._build_knn(point_clouds)
            phi = self.get_phi(point_clouds, **kwargs)
            point_clouds = self._denoise_normals(point_clouds,
                                                 phi,
                                                 points_filter,
                                                 inplace=False)

        # project the point to a local surface
        knn_diff = points_padded.unsqueeze(-2) - self.knn_tree.knn.detach()

        knn_normals = ops.knn_gather(point_clouds.normals_padded(),
                                     self.knn_tree.idx, lengths)
        pts_diff_proj = knn_diff - \
                (knn_diff * knn_normals).sum(dim=-1, keepdim=True) * knn_normals

        if get_debugging_mode():
            # points_padded.requires_grad_(True)

            def save_grad():
                lengths = point_clouds.num_points_per_cloud()

                def _save_grad(grad):
                    dbg_tensor = get_debugging_tensor()
                    if dbg_tensor is None:
                        logger_py.error("dbg_tensor is None")
                    if grad is None:
                        logger_py.error('grad is None')
                    # a dict of list of tensors
                    dbg_tensor.pts_world_grad['repel'] = [
                        grad[b, :lengths[b]].detach().cpu()
                        for b in range(grad.shape[0])
                    ]

                return _save_grad

            if points_padded.requires_grad:
                dbg_tensor = get_debugging_tensor()
                dbg_tensor.pts_world['repel'] = [
                    points_padded[b, :lengths[b]].detach().cpu()
                    for b in range(points_padded.shape[0])
                ]
                handle = points_padded.register_hook(save_grad())
                self.hooks.append(handle)

        with torch.autograd.no_grad():
            spatial_w = self.get_spatial_w(point_clouds, **kwargs)
            # set far neighbors' spatial_w to 0
            normal_w = self.get_normal_w(point_clouds, **kwargs)
            density_w = torch.sum(spatial_w, dim=-1, keepdim=True) + 1.0
            weights = spatial_w * normal_w

        # convert everything to packed
        weights = ops.padded_to_packed(
            weights, point_clouds.cloud_to_packed_first_idx(), P_total)
        pts_diff_proj = ops.padded_to_packed(
            pts_diff_proj.contiguous().view(pts_diff_proj.shape[0],
                                            pts_diff_proj.shape[1], -1),
            point_clouds.cloud_to_packed_first_idx(),
            P_total).view(P_total, -1, 3)
        density_w = ops.padded_to_packed(
            density_w, point_clouds.cloud_to_packed_first_idx(), P_total)

        # maximize the repulsion displacement, hence the negative sign inside
        # the exponential below
        repel_vec = torch.sum(pts_diff_proj * weights.unsqueeze(-1),
                              dim=1) / eps_denom(
                                  torch.sum(weights, dim=1).unsqueeze(-1))
        repel_vec = repel_vec * density_w

        loss = torch.exp(-repel_vec.abs())

        # if get_debugging_mode():
        #     # save to dbg folder as normal
        #     from ..utils.io import save_ply
        #     save_ply('./dbg_repel_diff.ply', point_clouds.points_packed().cpu().detach(), normals=repel_vec.cpu().detach())

        return loss
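
The normalization above divides by the summed weights through eps_denom, which is not defined in this example. Since spatial_w and normal_w are both exponentials and therefore non-negative, a plausible one-line guard (an assumption, not necessarily the repository's implementation) is:

import torch


def eps_denom(denom: torch.Tensor, eps: float = 1e-17) -> torch.Tensor:
    # keep the denominator strictly positive to avoid division by zero;
    # sufficient here because the summed weights are non-negative
    return torch.clamp(denom, min=eps)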
Example No. 16
def estimate_perpendicular(pc, k, sigma=0.01, clip=0.05):
    with torch.no_grad():
        # pc : [b, 3, n]
        b, _, n = pc.size()
        inter_KNN = knn_points(pc.permute(0, 2, 1),
                               pc.permute(0, 2, 1),
                               K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]
        nn_pts = knn_gather(pc.permute(0, 2, 1), inter_KNN.idx).permute(
            0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n ,k]

        # get covariance matrix and smallest eig-vector of individual point
        perpendi_vector_1 = []
        perpendi_vector_2 = []
        for i in range(b):
            curr_point_set = nn_pts[i].detach().permute(
                1, 0, 2)  #curr_point_set:[n, 3, k]
            curr_point_set_mean = torch.mean(
                curr_point_set, dim=2,
                keepdim=True)  #curr_point_set_mean:[n, 3, 1]
            curr_point_set = curr_point_set - curr_point_set_mean  #curr_point_set:[n, 3, k]
            curr_point_set_t = curr_point_set.permute(
                0, 2, 1)  #curr_point_set_t:[n, k, 3]
            fact = 1.0 / (k - 1)
            cov_mat = fact * torch.bmm(
                curr_point_set, curr_point_set_t)  #cov_mat:[n, 3, 3]
            # torch.symeig was removed in recent PyTorch; torch.linalg.eigh
            # likewise returns eigenvalues in ascending order
            eigenvalue, eigenvector = torch.linalg.eigh(
                cov_mat)  # eigenvalue:[n, 3], eigenvector:[n, 3, 3]

            larger_dim_idx = torch.topk(eigenvalue,
                                        2,
                                        dim=1,
                                        largest=True,
                                        sorted=False)[1]  # larger_dim_idx:[n, 2]

            persample_perpendi_vector_1 = torch.gather(
                eigenvector,
                2, larger_dim_idx[:, 0].unsqueeze(1).unsqueeze(2).expand(
                    n, 3, 1)).squeeze()  #persample_perpendi_vector_1:[n, 3]
            persample_perpendi_vector_2 = torch.gather(
                eigenvector,
                2, larger_dim_idx[:, 1].unsqueeze(1).unsqueeze(2).expand(
                    n, 3, 1)).squeeze()  #persample_perpendi_vector_2:[n, 3]

            perpendi_vector_1.append(persample_perpendi_vector_1.permute(1, 0))
            perpendi_vector_2.append(persample_perpendi_vector_2.permute(1, 0))

        perpendi_vector_1 = torch.stack(perpendi_vector_1,
                                        0)  #perpendi_vector_1:[b, 3, n]
        perpendi_vector_2 = torch.stack(perpendi_vector_2,
                                        0)  #perpendi_vector_2:[b, 3, n]

        aux_vector1 = sigma * torch.randn(
            b, n).unsqueeze(1).cuda()  #aux_vector1:[b, 1, n]
        aux_vector2 = sigma * torch.randn(
            b, n).unsqueeze(1).cuda()  #aux_vector2:[b, 1, n]

    return torch.clamp(perpendi_vector_1 * aux_vector1,
                       -1 * clip, clip) + torch.clamp(
                           perpendi_vector_2 * aux_vector2, -1 * clip, clip)
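
A hypothetical usage sketch (names and sizes are illustrative): estimate_perpendicular returns a clipped random offset lying in each point's local tangent plane, spanned by the two largest-variance eigenvectors of the neighborhood covariance, so adding the offset jitters the cloud roughly along the surface rather than off it.

import torch

pc = (torch.rand(4, 3, 1024) * 2 - 1).cuda()  # [b, 3, n] clouds in [-1, 1]^3
offset = estimate_perpendicular(pc, k=16, sigma=0.01, clip=0.05)  # [b, 3, n]
pc_jittered = pc + offset  # points move (approximately) tangentially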
Example No. 17
class Trainer(BaseTrainer):
    def __init__(self,
                 model,
                 optimizer,
                 scheduler,
                 generator,
                 train_loader,
                 val_loader,
                 device='cpu',
                 cameras=None,
                 log_dir=None,
                 vis_dir=None,
                 debug_dir=None,
                 val_dir=None,
                 threshold=0.0,
                 n_training_points=2048,
                 n_eval_points=4000,
                 lambda_occupied=1.,
                 lambda_freespace=1.,
                 lambda_rgb=1.,
                 lambda_eikonal=0.01,
                 patch_size=1,
                 clip_grad=True,
                 reduction_method='sum',
                 sample_continuous=False,
                 overwrite_visualization=True,
                 n_debug_points=-1,
                 saliency_sampling_3d=False,
                 resample_every=-1,
                 refresh_metric_every=-1,
                 gamma_n_points_dss=2.0,
                 gamma_n_rays=0.6,
                 gamma_lambda_rgb=1.0,
                 steps_n_points_dss=-1,
                 steps_n_rays=-1,
                 steps_lambda_rgb=-1,
                 limit_n_points_dss=24000,
                 limit_n_rays=1024,
                 steps_proj_tolerance=-1,
                 gamma_proj_tolerance=0.5,
                 limit_proj_tolerance=5e-5,
                 steps_lambda_sdf=-1,
                 gamma_lambda_sdf=1.0,
                 warm_up_iters=0,
                 sdf_alpha=5.0,
                 limit_sdf_alpha=100,
                 gamma_sdf_alpha=2,
                 steps_sdf_alpha=-1,
                 limit_lambda_freespace=1.0,
                 limit_lambda_occupied=1.0,
                 limit_lambda_rgb=1.0,
                 **kwargs):
        """Initialize the BaseModel class.
        Args:
            model (nn.Module)
            optimizer: optimizer
            scheduler: scheduler
            device: device
        """
        self.cfg = kwargs
        self.device = device
        self.model = model
        self.cameras = cameras

        self.val_loader = val_loader
        self.train_loader = train_loader

        self.tb_logger = SummaryWriter(
            log_dir + datetime.datetime.now().strftime("-%Y%m%d-%H%M%S"))

        # implicit function model
        self.vis_dir = vis_dir
        self.val_dir = val_dir
        self.threshold = threshold

        self.lambda_eikonal = lambda_eikonal
        self.lambda_occupied = lambda_occupied
        self.lambda_freespace = lambda_freespace
        self.lambda_rgb = lambda_rgb
        self.sdf_alpha = sdf_alpha

        self.generator = generator
        self.n_eval_points = n_eval_points
        self.patch_size = patch_size
        self.reduction_method = reduction_method
        self.sample_continuous = sample_continuous
        self.overwrite_visualization = overwrite_visualization
        self.saliency_sampling_3d = saliency_sampling_3d
        self.resample_every = resample_every
        self.refresh_metric_every = refresh_metric_every
        self.warm_up_iters = warm_up_iters

        #  tuple (score, mesh)
        self._mesh_cache = None
        self._pcl_cache = {}
        self.ref_pcl = None

        n_points_per_cloud = 0
        init_proj_tolerance = 0
        if isinstance(self.model, CombinedModel):
            n_points_per_cloud = self.model.n_points_per_cloud
            init_proj_tolerance = self.model.projection.proj_tolerance
        self.training_scheduler = TrainerScheduler(
            init_n_points_dss=n_points_per_cloud,
            init_n_rays=n_training_points,
            init_proj_tolerance=init_proj_tolerance,
            init_lambda_rgb=lambda_rgb,
            init_lambda_freespace=lambda_freespace,
            init_lambda_occupied=lambda_occupied,
            init_sdf_alpha=self.sdf_alpha,
            steps_n_points_dss=steps_n_points_dss,
            steps_n_rays=steps_n_rays,
            steps_proj_tolerance=steps_proj_tolerance,
            steps_sdf_alpha=steps_sdf_alpha,
            steps_lambda_rgb=steps_lambda_rgb,
            steps_lambda_sdf=steps_lambda_sdf,
            warm_up_iters=self.warm_up_iters,
            gamma_n_points_dss=gamma_n_points_dss,
            gamma_n_rays=gamma_n_rays,
            gamma_proj_tolerance=gamma_proj_tolerance,
            gamma_lambda_rgb=gamma_lambda_rgb,
            gamma_sdf_alpha=gamma_sdf_alpha,
            gamma_lambda_sdf=gamma_lambda_sdf,
            limit_n_points_dss=limit_n_points_dss,
            limit_n_rays=limit_n_rays,
            limit_proj_tolerance=limit_proj_tolerance,
            limit_sdf_alpha=limit_sdf_alpha,
            limit_lambda_rgb=limit_lambda_rgb,
            limit_lambda_occupied=limit_lambda_occupied,
            limit_lambda_freespace=limit_lambda_freespace)

        self.debug_dir = debug_dir
        self.hooks = []

        self.n_training_points = n_training_points
        self.n_debug_points = n_debug_points

        self.optimizer = optimizer
        self.scheduler = scheduler
        self.clip_grad = clip_grad

        self.iou_loss = IouLoss(reduction=self.reduction_method,
                                channel_dim=None)
        self.eikonal_loss = NormalLengthLoss(reduction=self.reduction_method)
        self.l1_loss = L1Loss(reduction=self.reduction_method)
        self.l2_loss = L2Loss(reduction=self.reduction_method)
        self.sdf_loss = SDF2DLoss(reduction=self.reduction_method)

    def _query_mesh(self):
        """
        Generate mesh at the current training (it), evaluate
        """
        if self._mesh_cache is None:
            try:
                mesh = self.generator.generate_mesh({},
                                                    with_colors=False,
                                                    with_normals=False)
            except Exception as e:
                return logger_py.error("Couldn\'t generate mesh {}".format(e))
            else:
                if self.cfg['model_selection_mode'] == 'maximize':
                    model_selection_sign = 1
                elif self.cfg['model_selection_mode'] == 'minimize':
                    model_selection_sign = -1

                self._mesh_cache = (-model_selection_sign * float('inf'), mesh)
                self._pcl_cache.clear()

        if self._mesh_cache is not None:
            return self._mesh_cache[1]

    def _query_pcl(self, n_points=-1):
        """
        Get a uniform point cloud on the iso-surface and save it in a cache.
        """
        if (n_points is None or n_points < 0) and len(self._pcl_cache) > 0:
            n_points = list(self._pcl_cache.keys())[0]
        iso_pcl = self._pcl_cache.get(n_points, None)
        new_pcl = False
        if iso_pcl is None:
            t0 = time.time()
            iso_pcl = sample_uniform_iso_points(
                self.model.decoder,
                n_points,
                bounding_sphere_radius=self.model.object_bounding_sphere,
                init_points=self.model._points.points_padded())
            t1 = time.time()

            normals = self.model.get_normals_from_grad(iso_pcl.points_padded(),
                                                       requires_grad=False)
            logger_py.debug(
                '[Sample from Mesh] time elapsed {}s'.format(t1 - t0))
            iso_pcl = PointClouds3D(iso_pcl.points_padded(), normals=normals)
            self._pcl_cache.clear()
            self._pcl_cache[n_points] = iso_pcl
            new_pcl = True
        return new_pcl, iso_pcl

    def evaluate_mesh(self, val_dataloader, it, **kwargs):
        logger_py.info("[Mesh Evaluation]")
        t0 = time.time()
        if not os.path.exists(self.val_dir):
            os.makedirs(self.val_dir)

        eval_list = defaultdict(list)

        mesh_gt = val_dataloader.dataset.get_meshes()
        assert (mesh_gt is not None)
        mesh_gt = mesh_gt.to(device=self.device)

        pointcloud_tgt = val_dataloader.dataset.get_pointclouds(
            num_points=self.n_eval_points)

        mesh = self.generator.generate_mesh({},
                                            with_colors=False,
                                            with_normals=False)
        # trimesh's sample_surface_even returns (points, face_indices) and may
        # return fewer points than requested
        points_pred, _ = trimesh.sample.sample_surface_even(
            mesh,
            pointcloud_tgt.points_packed().shape[0])
        # pytorch3d's chamfer_distance returns a (loss, loss_normals) tuple
        chamfer_dist, _ = chamfer_distance(
            pointcloud_tgt.points_padded(),
            torch.from_numpy(points_pred).view(1, -1, 3).to(
                device=pointcloud_tgt.points_padded().device,
                dtype=torch.float32))
        eval_dict_mesh = {'chamfer': chamfer_dist.item()}

        # save to "val" dict
        t1 = time.time()
        logger_py.info('[Mesh Evaluation] time elapsed {}s'.format(t1 - t0))
        if not mesh.is_empty:
            mesh.export(os.path.join(self.val_dir, "%010d.ply" % it))
        return eval_dict_mesh

    def eval_step(self, data, **kwargs):
        """
        evaluate with image mask iou or image rgb psnr
        """
        lights_model = kwargs.get('lights',
                                  self.val_loader.dataset.get_lights())
        cameras_model = kwargs.get('cameras',
                                   self.val_loader.dataset.get_cameras())
        img_size = self.generator.img_size
        eval_dict = {'iou': 0.0, 'psnr': 0.0}
        with autograd.no_grad():
            self.model.eval()
            data = self.process_data_dict(data,
                                          cameras_model,
                                          lights=lights_model)
            img_mask = data['mask_img']
            img = data['img']
            # render image
            rgbas = self.generator.raytrace_images(img_size,
                                                   img_mask,
                                                   cameras=data['camera'],
                                                   lights=data['light'])
            assert (len(rgbas) == 1)
            rgba = rgbas[0]
            rgba = torch.tensor(rgba[None, ...],
                                dtype=torch.float,
                                device=img_mask.device).permute(0, 3, 1, 2)

            # compare iou
            mask_gt = F.interpolate(img_mask.float(),
                                    img_size,
                                    mode='bilinear',
                                    align_corners=False).squeeze(1)
            mask_pred = rgba[:, 3, :, :]
            eval_dict['iou'] += self.iou_loss(mask_gt.float(),
                                              mask_pred.float(),
                                              reduction='mean')

            # compare psnr
            rgb_gt = F.interpolate(img,
                                   img_size,
                                   mode='bilinear',
                                   align_corners=False)
            rgb_pred = rgba[:, :3, :, :]
            # align_corners is an F.interpolate argument and does not belong
            # in the loss call
            eval_dict['psnr'] += self.l2_loss(rgb_gt,
                                              rgb_pred,
                                              channel_dim=1,
                                              reduction='mean').detach()

        return eval_dict

    def train_step(self, data, cameras, **kwargs):
        """
        Args:
            data (dict): contains img, img.mask and img.depth and camera_mat
            cameras (Cameras): Cameras object from pytorch3d
        Returns:
            loss
        """
        self.model.train()
        self.optimizer.zero_grad()
        it = kwargs.get("it", None)
        lights = kwargs.get('lights', None)
        if hasattr(self, 'training_scheduler'):
            self.training_scheduler.step(self, it)

        if isinstance(self.model, CombinedModel) and it > self.warm_up_iters:
            if self.resample_every > 0 and (
                    it - self.warm_up_iters) % self.resample_every == 0:
                self._pcl_cache.clear()
            is_new = self.sample_from_mesh(self.model.n_points_per_cloud)
            if self.saliency_sampling_3d:
                refresh_per_point_metric = is_new or \
                    ((self.refresh_metric_every > 0) and
                     ((it - 1) % self.refresh_metric_every == 0))
                self.model.eval()
                if refresh_per_point_metric:
                    self.ref_pcl = self.ref_per_point_metric(
                        mode=self.cfg['ref_metric'])
                    colors = scaler_to_color(
                        self.ref_pcl.features_packed().cpu().numpy().reshape(
                            -1))
                    save_ply(
                        os.path.join(self.vis_dir, '%010d_refpcl.ply' % it),
                        self.ref_pcl.points_packed().cpu().numpy(),
                        colors=colors,
                        normals=self.ref_pcl.normals_packed().cpu().numpy())

        data = self.process_data_dict(data, cameras, lights=lights)
        self.model.train()
        # autograd.set_detect_anomaly(True)
        loss = self.compute_loss(data['img'],
                                 data['mask_img'],
                                 data['input'],
                                 data['camera'],
                                 data['light'],
                                 it=it,
                                 ref_pcl=self.ref_pcl)
        loss.backward()
        if self.clip_grad:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                           max_norm=1.)
        self.optimizer.step()
        check_weights(self.model.state_dict())

        return loss.item()

    def process_data_dict(self, data, cameras, lights=None):
        ''' Processes the data dictionary and returns respective tensors

        Args:
            data (dictionary): data dictionary
        '''
        device = self.device

        # Get "ordinary" data
        img = data.get('img.rgb').to(device)
        assert (img.min() >= 0 and img.max() <= 1
                ), "Image values must be floats between 0 and 1."
        mask_img = data.get('img.mask').to(device)

        camera_mat = data.get('camera_mat', None)

        # inputs for SVR
        inputs = data.get('inputs', torch.empty(0, 0)).to(device)

        # set camera matrix to cameras
        if camera_mat is None:
            logger_py.warning(
                "Camera matrix is not provided! Using the default matrix")
        else:
            cameras.R, cameras.T = decompose_to_R_and_t(camera_mat)
            cameras._N = cameras.R.shape[0]
            cameras.to(device)

        if lights is not None:
            lights_params = data.get('lights', None)
            if lights_params is not None:
                lights = type(lights)(**lights_params).to(device)

        return {
            'img': img,
            'mask_img': mask_img,
            'input': inputs,
            'camera': cameras,
            'light': lights
        }

    def sample_pixels(self, n_rays: int, batch_size: int, h: int, w: int):

        if n_rays >= h * w:
            p = arange_pixels((h, w), batch_size)[1].to(self.device)
        else:
            p = sample_patch_points(
                batch_size,
                n_rays,
                patch_size=self.patch_size,
                image_resolution=(h, w),
                continuous=self.sample_continuous,
            ).to(self.device)
        return p

    def sample_from_mesh(self, n_points):
        """
        Construct mesh from implicit model and sample from the mesh to get iso-surface points,
        which is used for projection in the combined model
        Return True is a new pcl is sampled
        """
        try:
            new_pcl, pcl = self._query_pcl(n_points)
            self.model._points = pcl
            self.model.points = pcl.points_padded()
            if not os.path.exists(self.vis_dir):
                os.makedirs(self.vis_dir)
        except Exception as e:
            logger_py.error("Couldn't sample points from mesh: {}".format(e))
            return False

        return new_pcl

    def compute_loss(self,
                     img,
                     mask_img,
                     inputs,
                     cameras,
                     lights,
                     n_points=None,
                     eval_mode=False,
                     it=None,
                     ref_pcl=None):
        ''' Compute the loss.
        Args:
            img (tensor): (B, C, H, W) ground-truth RGB images
            mask_img (tensor): (B, 1, H, W) ground-truth object masks
            eval_mode (bool): whether to use eval mode
            it (int): training iteration
        '''
        # Initialize loss dictionary and other values
        loss = {}

        # overwrite n_points
        if n_points is None:
            n_points = self.n_eval_points if eval_mode else self.n_training_points

        # Shortcuts
        device = self.device
        patch_size = self.patch_size
        reduction_method = self.reduction_method
        batch_size, _, h, w = img.shape

        # Assertions
        assert (((h, w) == mask_img.shape[2:4]) and (patch_size > 0))

        # Apply losses
        # 1.) Initialize loss
        loss['loss'] = 0
        if isinstance(self.model, CombinedModel):
            # 1.) Sample points on image plane ("pixels")
            p = None
            if n_points > 0:
                p = self.sample_pixels(n_points, batch_size, h, w)

            project = (it - self.warm_up_iters > 0) and not eval_mode
            sample_iso_offsurface = project
            # TODO: check insertion fix: start using insertion after 5000 iterations
            saliency_sampling_3d = self.saliency_sampling_3d
            model_outputs = self.model(
                mask_img,
                img,
                cameras,
                pixels=p,
                inputs=inputs,
                lights=lights,
                it=it,
                eval_mode=eval_mode,
                project=project,
                sample_iso_offsurface=sample_iso_offsurface,
                proj_kwargs={
                    'ref_pcl': ref_pcl if saliency_sampling_3d else None,
                    'insert': saliency_sampling_3d
                })
        else:
            # 1.) Sample points on image plane ("pixels")
            p = self.sample_pixels(n_points, batch_size, h, w)

            model_outputs = self.model(p,
                                       img,
                                       mask_img,
                                       inputs=inputs,
                                       cameras=cameras,
                                       lights=lights,
                                       it=it)

        point_clouds = model_outputs.get('iso_pcl')
        rgb_gt = model_outputs.get('iso_rgb_gt')
        sdf_freespace = model_outputs.get('sdf_freespace')
        sdf_occupancy = model_outputs.get('sdf_occupancy')

        if it % 50 == 0 and not eval_mode and not get_debugging_mode():
            logger_py.debug('# iso: {}, # occ_off: {}, # free_off: {}'.format(
                model_outputs['iso_rgb_gt'].shape[0],
                model_outputs['p_occupancy'].shape[0],
                model_outputs['p_freespace'].shape[0]))

        if not point_clouds.isempty():
            # Photo Consistency Loss
            normalizing_value = 1.0
            if reduction_method == 'sum':
                total_p = batch_size * self.training_scheduler.init_n_rays
                normalizing_value = 1.0 / total_p * min(
                    (sdf_freespace.nelement() + sdf_occupancy.nelement()) /
                    float(rgb_gt.size(0)), 1.0)
            self.calc_photoconsistency_loss(
                point_clouds,
                rgb_gt,
                reduction_method,
                loss,
                normalizing_value=normalizing_value)

        # Occupancy and Freespace losses
        if self.lambda_occupied > 0 or self.lambda_freespace > 0:
            normalizing_value = 1.0
            if reduction_method == 'sum':
                total_p = batch_size * self.training_scheduler.init_n_rays
                normalizing_value = 1.0 / total_p
            self.calc_sdf_mask_loss(sdf_freespace,
                                    sdf_occupancy,
                                    self.sdf_alpha,
                                    reduction_method,
                                    loss,
                                    normalizing_value=normalizing_value)

        # Eikonal loss
        # Random samples in the space
        total_p = batch_size * self.training_scheduler.init_n_rays
        space_pts = torch.empty(total_p, 3).uniform_(
            -self.model.object_bounding_sphere,
            self.model.object_bounding_sphere).to(device=device)
        # space_pts = torch.cat([model_outputs.get('p_freespace'), model_outputs.get('p_occupancy'), point_clouds.points_packed()]).detach()
        eikonal_normals = self.model.get_normals_from_grad(space_pts,
                                                           c=inputs,
                                                           requires_grad=True)
        normalizing_value = 1.0
        if reduction_method == 'sum':
            normalizing_value = 1.0 / total_p
        self.calc_eikonal_loss(eikonal_normals,
                               reduction_method,
                               loss,
                               normalizing_value=normalizing_value)

        for k, v in loss.items():
            mode = 'val' if eval_mode else 'train'
            if isinstance(v, torch.Tensor):
                self.tb_logger.add_scalar('%s/%s' % (mode, k), v.item(), it)
            else:
                self.tb_logger.add_scalar('%s/%s' % (mode, k), v, it)

        return loss if eval_mode else loss['loss']

    def ref_per_point_metric(self,
                             ref_pcl: PointClouds3D = None,
                             mode='curvature'):
        """
        Computes the metric used for sampling or weighting, e.g. curvature / loss,
        and update to pcl's features
        When mode == 'curvature', estimates the shape curvature from the point cloud using
        a local neighborhood of 12 points,
        When mode == 'loss', average the per point per view RGB loss over the entire training images
        """
        if ref_pcl is None:
            ref_pcl = self.model._points

        if (n_points := ref_pcl.points_packed().shape[0]) > 5000:
            ref_pcl = farthest_sampling(ref_pcl, 5000 / n_points)

        with autograd.no_grad():
            if mode == 'loss':
                self.model.eval()

                cameras = self.val_loader.dataset.get_cameras()
                lights = self.val_loader.dataset.get_lights()

                # project all init_points to the surface
                proj_result = self.model.projection.project_points(
                    ref_pcl,
                    self.model.decoder,
                    skip_resampling=False,
                    skip_upsampling=False,
                    sample_iters=2)
                ref_pcl = PointClouds3D(
                    proj_result['levelset_points'][proj_result['mask']].view(
                        1, -1, 3),
                    normals=proj_result['levelset_normals'][
                        proj_result['mask']].view(1, -1, 3))
                num_points2 = ref_pcl.num_points_per_cloud()
                logger_py.info(
                    '[Per Point Loss Metric] evaluating ref point cloud ({}) on all training images'
                    .format(num_points2.item()))
                assert (len(num_points2) == 1)
                from ..utils.mathHelper import RunningStat
                runStat = RunningStat(num_points2.item(),
                                      device=ref_pcl.device)

                # set max_iso_per_batch to -1
                max_iso_per_batch = self.model.max_iso_per_batch
                self.model.max_iso_per_batch = -1
                # If loss is used, transfer the computed loss in the current point cloud to the reference point cloud
                for batch in tqdm(self.val_loader):
                    data = self.process_data_dict(batch,
                                                  cameras=cameras,
                                                  lights=lights)
                    mask_img, img, cameras, lights = data['mask_img'], data[
                        'img'], data['camera'], data['light']
                    # set proj_max_iters to 0 because we already projected the points beforehand
                    model_outputs = self.model(
                        mask_img,
                        img,
                        cameras,
                        mask_gt=None,
                        pixels=None,
                        inputs=None,
                        lights=lights,
                        project=True,
                        sample_iso_offsurface=False,
                    )

                    point_clouds = model_outputs['iso_pcl']
                    pixel_pred = model_outputs['iso_pixel']
                    rgb_gt = model_outputs['iso_rgb_gt']
                    sig = num_points2.item() / self.model.object_bounding_sphere
                    dist_thres = 4 / sig
                    if not point_clouds.isempty():
                        num_points1 = point_clouds.num_points_per_cloud().sum(
                            0, keepdim=True)
                        loss = {'loss': 0.0}
                        self.calc_photoconsistency_loss(
                            point_clouds,
                            rgb_gt,
                            'none',
                            loss,
                        )
                        loss_per_point = loss['loss_rgb'] / self.lambda_rgb

                        query_points = point_clouds.points_packed()
                        dists, idxs, _ = knn_points(
                            ref_pcl.points_padded().contiguous(),
                            query_points.unsqueeze(0).contiguous(),
                            num_points2,
                            num_points1,
                            K=1,
                            return_nn=False)
                        loss_ref_point = knn_gather(
                            loss_per_point.view(1, num_points1.item(), 1),
                            idxs, num_points1).view(-1, 1)
                        mask = (dists < dist_thres) & (dists > 0)
                        runStat.add(loss_ref_point.view(-1), mask.view(-1))

                per_point_metric = runStat.mean().view(-1, 1)
                logger_py.debug(
                    '[Per Point Loss Metric] ref point cloud metric min {} max {} median {}'
                    .format(per_point_metric.min(), per_point_metric.max(),
                            per_point_metric.median()))

                self.model.max_iso_per_batch = max_iso_per_batch
            # or curvature is used
            elif mode == 'curvature':
                curvatures, _ = estimate_pointcloud_local_coord_frames(
                    ref_pcl,
                    neighborhood_size=12,
                    disambiguate_directions=False,
                    return_knn_result=False)
                # high curvature area : variance in the local frame is large
                curvatures = curvatures[..., 0] / \
                    eps_denom(curvatures[..., -1])
                per_point_metric = curvatures.view(-1, 1)

        ref_pcl = ref_pcl.update_features_(per_point_metric)

        return ref_pcl
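
ref_per_point_metric accumulates a masked per-point average of the RGB loss over many views through the RunningStat helper imported from ..utils.mathHelper. That helper is not shown here; the following is a minimal sketch of a masked running mean with the same add/mean interface, inferred from how it is called above:

import torch


class RunningStat:
    """Hypothetical masked running mean over a fixed set of points."""

    def __init__(self, n_points: int, device='cpu'):
        self.sum = torch.zeros(n_points, device=device)
        self.count = torch.zeros(n_points, device=device)

    def add(self, values: torch.Tensor, mask: torch.Tensor):
        # accumulate only where the nearest-neighbor match was valid
        valid = mask.to(values.dtype)
        self.sum += values * valid
        self.count += valid

    def mean(self) -> torch.Tensor:
        # points that were never observed keep a mean of 0
        return self.sum / self.count.clamp(min=1)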
Example No. 18
def _compute_sampling_metrics(pred_points, pred_normals, gt_points, gt_normals,
                              thresholds, eps):
    """
    Compute metrics that are based on sampling points and normals:

    - L2 Chamfer distance
    - Precision at various thresholds
    - Recall at various thresholds
    - F1 score at various thresholds
    - Normal consistency (if normals are provided)
    - Absolute normal consistency (if normals are provided)

    Inputs:
        - pred_points: Tensor of shape (N, S, 3) giving coordinates of sampled points
          for each predicted mesh
        - pred_normals: Tensor of shape (N, S, 3) giving normals of points sampled
          from the predicted mesh, or None if such normals are not available
        - gt_points: Tensor of shape (N, S, 3) giving coordinates of sampled points
          for each ground-truth mesh
        - gt_normals: Tensor of shape (N, S, 3) giving normals of points sampled from
          the ground-truth verts, or None if such normals are not available
        - thresholds: Distance thresholds to use for precision / recall / F1
        - eps: epsilon value to handle numerically unstable F1 computation

    Returns:
        - metrics: A dictionary where keys are metric names and values are Tensors of
          shape (N,) giving the value of the metric for the batch
    """
    metrics = {}
    lengths_pred = torch.full((pred_points.shape[0], ),
                              pred_points.shape[1],
                              dtype=torch.int64,
                              device=pred_points.device)
    lengths_gt = torch.full((gt_points.shape[0], ),
                            gt_points.shape[1],
                            dtype=torch.int64,
                            device=gt_points.device)

    # For each predicted point, find its nearest-neighbor GT point
    knn_pred = knn_points(pred_points,
                          gt_points,
                          lengths1=lengths_pred,
                          lengths2=lengths_gt,
                          K=1)
    # Compute squared and unsquared L2 distances between each pred point and
    # its nearest GT point
    pred_to_gt_dists2 = knn_pred.dists[..., 0]  # (N, S)
    pred_to_gt_dists = pred_to_gt_dists2.sqrt()  # (N, S)
    if gt_normals is not None:
        pred_normals_near = knn_gather(gt_normals, knn_pred.idx,
                                       lengths_gt)[..., 0, :]  # (N, S, 3)
    else:
        pred_normals_near = None

    # For each GT point, find its nearest-neighbor predicted point
    knn_gt = knn_points(gt_points,
                        pred_points,
                        lengths1=lengths_gt,
                        lengths2=lengths_pred,
                        K=1)
    # Compute squared and unsquared L2 distances between each GT point and its
    # nearest pred point
    gt_to_pred_dists2 = knn_gt.dists[..., 0]  # (N, S)
    gt_to_pred_dists = gt_to_pred_dists2.sqrt()  # (N, S)

    if pred_normals is not None:
        gt_normals_near = knn_gather(pred_normals, knn_gt.idx,
                                     lengths_pred)[..., 0, :]  # (N, S, 3)
    else:
        gt_normals_near = None

    # Compute L2 chamfer distances
    chamfer_l2 = pred_to_gt_dists2.mean(dim=1) + gt_to_pred_dists2.mean(dim=1)
    metrics["Chamfer-L2"] = chamfer_l2

    # Compute normal consistency and absolute normal consistency only if
    # we actually got normals for both meshes
    if pred_normals is not None and gt_normals is not None:
        pred_to_gt_cos = F.cosine_similarity(pred_normals,
                                             pred_normals_near,
                                             dim=2)
        gt_to_pred_cos = F.cosine_similarity(gt_normals,
                                             gt_normals_near,
                                             dim=2)

        pred_to_gt_cos_sim = pred_to_gt_cos.mean(dim=1)
        pred_to_gt_abs_cos_sim = pred_to_gt_cos.abs().mean(dim=1)
        gt_to_pred_cos_sim = gt_to_pred_cos.mean(dim=1)
        gt_to_pred_abs_cos_sim = gt_to_pred_cos.abs().mean(dim=1)
        normal_dist = 0.5 * (pred_to_gt_cos_sim + gt_to_pred_cos_sim)
        abs_normal_dist = 0.5 * (pred_to_gt_abs_cos_sim +
                                 gt_to_pred_abs_cos_sim)
        metrics["NormalConsistency"] = normal_dist
        metrics["AbsNormalConsistency"] = abs_normal_dist

    # Compute precision, recall, and F1 based on L2 distances
    for t in thresholds:
        precision = 100.0 * (pred_to_gt_dists < t).float().mean(dim=1)
        recall = 100.0 * (gt_to_pred_dists < t).float().mean(dim=1)
        f1 = (2.0 * precision * recall) / (precision + recall + eps)
        metrics["Precision@%f" % t] = precision
        metrics["Recall@%f" % t] = recall
        metrics["F1@%f" % t] = f1

    # Move all metrics to CPU
    metrics = {k: v.cpu() for k, v in metrics.items()}
    return metrics
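
A hypothetical smoke test (tensor sizes are illustrative): called without normals, the function skips the consistency metrics and returns per-batch Chamfer, precision, recall and F1 tensors of shape (N,).

import torch

pred = torch.rand(2, 1000, 3)
gt = torch.rand(2, 1000, 3)
metrics = _compute_sampling_metrics(pred, None, gt, None,
                                    thresholds=[0.1, 0.2], eps=1e-8)
print(metrics["Chamfer-L2"].shape)  # torch.Size([2])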