Code example #1
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def _get_kappa_adv(adv_pc, ori_pc, ori_normal, k=2):
    b, _, n = adv_pc.size()
    # compute knn between advPC and oriPC to get normal n_p
    #intra_dis = ((adv_pc.unsqueeze(3) - ori_pc.unsqueeze(2))**2).sum(1)
    #intra_idx = torch.topk(intra_dis, 1, dim=2, largest=False, sorted=True)[1]
    #normal = torch.gather(ori_normal, 2, intra_idx.view(b,1,n).expand(b,3,n))
    intra_KNN = knn_points(adv_pc.permute(0, 2, 1),
                           ori_pc.permute(0, 2, 1),
                           K=1)  #[dists:[b,n,1], idx:[b,n,1]]
    normal = knn_gather(ori_normal.permute(0, 2, 1), intra_KNN.idx).permute(
        0, 3, 1, 2).squeeze(3).contiguous()  # [b, 3, n]

    # compute knn between advPC and itself to get \|q-p\|_2
    #inter_dis = ((adv_pc.unsqueeze(3) - adv_pc.unsqueeze(2))**2).sum(1)
    #inter_idx = torch.topk(inter_dis, k+1, dim=2, largest=False, sorted=True)[1][:, :, 1:].contiguous()
    #nn_pts = torch.gather(adv_pc, 2, inter_idx.view(b,1,n*k).expand(b,3,n*k)).view(b,3,n,k)
    inter_KNN = knn_points(adv_pc.permute(0, 2, 1),
                           adv_pc.permute(0, 2, 1),
                           K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]
    nn_pts = knn_gather(adv_pc.permute(0, 2, 1), inter_KNN.idx).permute(
        0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n, k]
    vectors = nn_pts - adv_pc.unsqueeze(3)
    vectors = _normalize(vectors)

    return torch.abs(
        (vectors *
         normal.unsqueeze(3)).sum(1)).mean(2), normal  # [b, n], [b, 3, n]
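A minimal usage sketch for the function above (not from the repo): `_normalize` is the module's own helper, approximated here by unit-normalizing along the channel dimension, and the clouds follow the channel-first [b, 3, n] layout used throughout loss_utils.py.

import torch
import torch.nn.functional as F
from pytorch3d.ops import knn_points, knn_gather

def _normalize(v):
    # assumed stand-in for the repo's helper: unit-length offset vectors
    return F.normalize(v, dim=1)

b, n = 4, 1024
adv_pc = torch.rand(b, 3, n)
ori_pc = adv_pc + 0.01 * torch.randn(b, 3, n)
ori_normal = F.normalize(torch.randn(b, 3, n), dim=1)
adv_kappa, normal = _get_kappa_adv(adv_pc, ori_pc, ori_normal, k=2)  # [b, n], [b, 3, n]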
Code example #2
def eval_batch(points_pred, points_gt):
    d_1, _, _ = knn_points(points_pred, points_gt)
    d_2, _, _ = knn_points(points_gt, points_pred)  # second direction for the symmetric error
    err_1, _ = d_1.squeeze(-1).max(dim=1)
    err_2, _ = d_2.squeeze(-1).max(dim=1)
    err = torch.cat((err_1.unsqueeze(-1), err_2.unsqueeze(-1)), dim=-1)
    e, _ = err.max(dim=1)
    return e
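A quick sanity check for eval_batch, assuming batched (B, N, 3) clouds (a sketch, not from the source):

import torch
from pytorch3d.ops import knn_points

pred = torch.rand(2, 500, 3)
gt = pred.clone()
# identical clouds give zero two-sided error
assert torch.allclose(eval_batch(pred, gt), torch.zeros(2))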
Code example #3
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def chamfer_loss(adv_pc, ori_pc):
    # Chamfer distance (two sides)
    #intra_dis = ((adv_pc.unsqueeze(3) - ori_pc.unsqueeze(2))**2).sum(1)
    #dis_loss = intra_dis.min(2)[0].mean(1) + intra_dis.min(1)[0].mean(1)
    adv_KNN = knn_points(adv_pc.permute(0, 2, 1), ori_pc.permute(0, 2, 1),
                         K=1)  #[dists:[b,n,1], idx:[b,n,1]]
    ori_KNN = knn_points(ori_pc.permute(0, 2, 1), adv_pc.permute(0, 2, 1),
                         K=1)  #[dists:[b,n,1], idx:[b,n,1]]
    dis_loss = adv_KNN.dists.contiguous().squeeze(-1).mean(
        -1) + ori_KNN.dists.contiguous().squeeze(-1).mean(-1)  #[b]
    return dis_loss
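Note that knn_points returns squared distances, so this is the squared-distance Chamfer. A hypothetical call, again with the channel-first [b, 3, n] layout:

import torch
from pytorch3d.ops import knn_points

adv = torch.rand(8, 3, 1024)
ori = torch.rand(8, 3, 1024)
print(chamfer_loss(adv, ori).shape)  # torch.Size([8])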
Code example #4
File: utils.py  Project: murnanedaniel/end-to-end
def build_edges(spatial,
                r_max,
                k_max,
                return_indices=False,
                target_spatial=None):

    if k_max > 200:
        if device == "cuda":
            res = faiss.StandardGpuResources()
            if target_spatial is None:
                D, I = faiss.knn_gpu(res, spatial, spatial, k_max)
            else:
                D, I = faiss.knn_gpu(res, spatial, target_spatial, k_max)
        elif device == "cpu":
            index = faiss.IndexFlatL2(spatial.shape[1])
            index.add(spatial)
            if target_spatial is None:
                D, I = index.search(spatial, k_max)
            else:
                D, I = index.search(target_spatial, k_max)

    else:
        if target_spatial is None:
            knn_object = ops.knn_points(spatial.unsqueeze(0),
                                        spatial.unsqueeze(0),
                                        K=k_max,
                                        return_sorted=False)
        else:
            knn_object = ops.knn_points(
                spatial.unsqueeze(0),
                target_spatial.unsqueeze(0),
                K=k_max,
                return_sorted=False,
            )
        I = knn_object.idx[0]
        D = knn_object.dists[0]

    # Overlay the "source" hit ID onto each neighbour ID (this is necessary as the FAISS algo does some shortcuts)
    ind = torch.arange(I.shape[0], device=device).unsqueeze(1).expand_as(I)
    edge_list = torch.stack([ind[D <= r_max**2], I[D <= r_max**2]])

    # Remove self-loops
    edge_list = edge_list[:, edge_list[0] != edge_list[1]]

    if return_indices:
        return edge_list, D, I, ind
    else:
        return edge_list
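A sketch of driving build_edges on the pytorch3d branch (k_max <= 200); `device` is read as a module-level global, which is an assumption about the surrounding file:

import torch
from pytorch3d import ops

device = "cuda" if torch.cuda.is_available() else "cpu"
spatial = torch.rand(5000, 8, device=device)  # embedded hit coordinates
edge_list = build_edges(spatial, r_max=0.1, k_max=50)
print(edge_list.shape)  # (2, num_edges), self-loops removed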
Code example #5
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def hausdorff_loss(adv_pc, ori_pc):
    #dis = ((adv_pc.unsqueeze(3) - ori_pc.unsqueeze(2))**2).sum(1)
    #hd_loss = torch.max(torch.min(dis, dim=2)[0], dim=1)[0]
    adv_KNN = knn_points(adv_pc.permute(0, 2, 1), ori_pc.permute(0, 2, 1),
                         K=1)  #[dists:[b,n,1], idx:[b,n,1]]
    hd_loss = adv_KNN.dists.contiguous().squeeze(-1).max(-1)[0]  #[b]
    return hd_loss
Code example #6
    def forward(self, x, nsample, xyz):
        if nsample == 1:
            sampled_idx = None
            sampled_xyz = torch.mean(xyz, dim=-1, keepdim=True)
            _, sampled_idx, sampled_xyz = ops.knn_points(sampled_xyz.transpose(
                1, 2),
                                                         xyz.transpose(1, 2),
                                                         return_nn=True)
            sampled_xyz = sampled_xyz.squeeze(2)
            sampled_idx = sampled_idx.squeeze(1)
        else:
            sampled_idx, sampled_xyz = furthest_point_sample(xyz,
                                                             nsample,
                                                             NCHW=True)

        sampled_x = gather_points(x, sampled_idx)
        for i, mlp in enumerate(self.mlps):
            if i == 0:
                y, idx = self.get_local_graph(sampled_x, x, k=self.k)
                sampled_x = sampled_x.unsqueeze(-1).expand(-1, -1, -1, self.k)
                y = torch.cat([nn.functional.relu_(mlp(y)), sampled_x], dim=1)
            elif i == (self.n - 1):
                y = torch.cat([mlp(y), y], dim=1)
            else:
                y = torch.cat([nn.functional.relu_(mlp(y)), y], dim=1)

        y, _ = torch.max(y, dim=-1)
        return y, sampled_xyz, sampled_idx
Code example #7
    def get_local_graph(self, query, x, k, idx=None):
        """Construct edge feature [x, NN_i - x] for each point x
        :param
            x: (B, C, N)
            k: int
            idx: (B, N, k)
        :return
            edge features: (B, C, N, k)
        """
        if idx is None:
            # dists: (B, M, k+1), idx: (B, M, k+1), knn: (B, M, k+1, C)
            _, idx, knn_point = ops.knn_points(query.transpose(1, 2),
                                               x.transpose(1, 2),
                                               K=k + 1,
                                               return_nn=True)
            idx = idx[:, :, 1:]
            knn_point = knn_point.permute(0, 3, 1, 2)  # (B, C, M, k+1)
            knn_point = knn_point[:, :, :, 1:]  # drop the query point itself

        neighbor_center = torch.unsqueeze(query, dim=-1)
        neighbor_center = neighbor_center.expand_as(knn_point)

        edge_feature = torch.cat(
            [neighbor_center, knn_point - neighbor_center], dim=1)
        return edge_feature, idx
Code example #8
def pointUniformLaplacian(points, knn_idx=None, nn_size=3):
    """
    Args:
        points: (B, N, 3)
        knn_idx: (B, N, K)
    Returns:
        laplacian: (B, N, 1)
    """
    batch_size, num_points, _ = points.shape
    if knn_idx is None:
        # find neighborhood, (B,N,K,3), (B,N,K)
        _, knn_idx, group_points = ops.knn_points(points, points, K=nn_size+1, return_nn=True)
        knn_idx = knn_idx[:, :, 1:]
        group_points = group_points[:, :, 1:, :]
    else:
        points_expanded = points.unsqueeze(dim=1).expand(
            (-1, num_points, -1, -1))
        # BxNxk -> BxNxkxC
        index_batch_expanded = knn_idx.unsqueeze(dim=-1).expand(
            (-1, -1, -1, points.size(-1)))
        # BxMxkxC
        group_points = torch.gather(points_expanded, 2, index_batch_expanded)

    lap = -torch.sum(group_points, dim=2)/knn_idx.shape[2] + points
    return lap, knn_idx
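A small sketch exercising the default path, where the neighborhood is computed internally:

import torch
from pytorch3d import ops

points = torch.rand(2, 256, 3)
lap, knn_idx = pointUniformLaplacian(points, nn_size=3)
print(lap.shape, knn_idx.shape)  # torch.Size([2, 256, 3]) torch.Size([2, 256, 3])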
Code example #9
File: losses.py  Project: PeterZhouSZ/iso-points
    def _build_knn(self, point_clouds):
        """
        search for KNN again set knn_tree and knn_mask attributes
        TODO(yifan): use a real Kd_tree library to be able to store the data tree and
        query at each forward pass?
        """
        # Find local neighborhood to compute weights
        with torch.autograd.enable_grad():
            points_padded = point_clouds.points_padded()

        lengths = point_clouds.num_points_per_cloud()
        knn_result = ops.knn_points(points_padded,
                                    points_padded,
                                    lengths,
                                    lengths,
                                    K=self.knn_k,
                                    return_nn=True)
        self.knn_mask = torch.full(knn_result.idx.shape,
                                   False,
                                   dtype=torch.bool,
                                   device=points_padded.device)
        # valid knn result
        for b in range(self.knn_mask.shape[0]):
            self.knn_mask[
                b, :lengths[b], :min(self.knn_k, lengths[b].item())] = True
            assert (torch.all(knn_result.dists[b][~self.knn_mask[b]] == 0))
        self.knn_tree = KNN(knn=knn_result.knn[:, :, 1:, :],
                            dists=knn_result.dists[:, :, 1:],
                            idx=knn_result.idx[:, :, 1:])
        self.knn_mask = self.knn_mask[:, :, 1:]
        assert (self.knn_mask.shape == self.knn_tree.dists.shape)
Code example #10
    def forward(self, points, knn_idx=None):
        batchSize, PN, _ = points.shape
        if knn_idx is None:
            distance2, knn_idx, knn_points = ops.knn_points(points,
                                                            points,
                                                            K=self.nn_size + 1,
                                                            return_nn=True)
            knn_points = knn_points[:, :, 1:, :].contiguous().detach()
            knn_idx = knn_idx[:, :, 1:].contiguous()
        else:
            knn_points = torch.gather(
                points.unsqueeze(1).expand(-1, PN, -1, -1), 2,
                knn_idx.unsqueeze(-1).expand(-1, -1, -1, points.shape[-1]))

        knn_v = knn_points - points.unsqueeze(dim=2)
        distance2 = torch.sum(knn_v * knn_v, dim=-1)
        loss = 1 / torch.sqrt(distance2 + 1e-4)
        loss = torch.where(distance2 < self.radius2, loss,
                           torch.zeros_like(loss))
        if self.reduction == "mean":
            return loss.mean()
        elif self.reduction == "max":
            return torch.mean(torch.max(loss, dim=-1)[0])
        elif self.reduction == "sum":
            return loss.mean(torch.sum(loss, dim=-1))
        elif self.reduction == "none":
            return loss
        else:
            raise NotImplementedError
        return loss
Code example #11
 def forward(self, points_ref, points):
     """
     point1: (B,N,D) ref points (where connectivity is computed)
     point2: (B,N,D) pred points, uses connectivity of point1
     """
     # find neighborhood, (B,N,K,3), (B,N,K), (B,N,K)
     _, knn_idx, group_points_ref = ops.knn_points(points_ref,
                                                   points_ref,
                                                   K=self.nn_size + 1,
                                                   return_nn=True)
     knn_idx = knn_idx[:, :, 1:]
     group_points_ref = group_points_ref[:, :, 1:, :]
     dist_ref = torch.norm(group_points_ref - points_ref.unsqueeze(2),
                           dim=-1,
                           p=2)
     group_points = torch.gather(
         points.unsqueeze(1).expand(-1, knn_idx.shape[1], -1, -1), 2,
         knn_idx.unsqueeze(-1).expand(-1, -1, -1, points.shape[-1]))
     dist = torch.norm(group_points - points.unsqueeze(2), dim=-1, p=2)
     stretch = torch.max(dist / (dist_ref + 1e-10) - 1,
                         torch.zeros_like(dist))
     if self.reduction == "mean":
         return torch.mean(stretch)
     elif self.reduction == "sum":
         return torch.mean(torch.sum(stretch, dim=-1))
     elif self.reduction == "none":
         return stretch
     elif self.reduction == "max":
         return torch.mean(torch.max(stretch, dim=-1)[0])
     else:
         raise NotImplementedError
Code example #12
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def pseudo_chamfer_loss(adv_pc, ori_pc):
    # Chamfer pseudo distance (one side)
    #intra_dis = ((adv_pc.unsqueeze(3) - ori_pc.unsqueeze(2))**2).sum(1) #b*n*n
    #dis_loss = intra_dis.min(2)[0].mean(1)
    adv_KNN = knn_points(adv_pc.permute(0, 2, 1), ori_pc.permute(0, 2, 1),
                         K=1)  #[dists:[b,n,1], idx:[b,n,1]]
    dis_loss = adv_KNN.dists.contiguous().squeeze(-1).mean(-1)  #[b]
    return dis_loss
Code example #13
File: rasterizer.py  Project: yifita/DSS
    def _compute_global_Vrk(self, pointclouds, refresh=True, **kwargs):
        """
        determine variance scaler used in globally (see _compute_isotropic_Vrk)
        Args:
            pointclouds: pointclouds in object coorindates
        Returns:
            h_k: scaler
            S_k: local frame
        """
        if not refresh and self._Vrk_h is not None:
            h_k = self._Vrk_h
        else:
            # compute average density
            with torch.autograd.enable_grad():
                pts_world = pointclouds.points_padded()

            num_points_per_cloud = pointclouds.num_points_per_cloud()
            if self.frnn_radius <= 0:
                # use knn here
                # logger_py.info("vrk knn points")
                sq_dist, _, _ = ops3d.knn_points(pts_world,
                                                 pts_world,
                                                 num_points_per_cloud,
                                                 num_points_per_cloud,
                                                 K=7)
            else:
                sq_dist, _, _, _ = frnn.frnn_grid_points(pts_world,
                                                         pts_world,
                                                         num_points_per_cloud,
                                                         num_points_per_cloud,
                                                         K=7,
                                                         r=self.frnn_radius)
            # logger_py.info("frnn and knn dist close: {}".format(torch.allclose(sq_dist, sq_dist2)))
            sq_dist = sq_dist[:, :, 1:]
            # knn search is unreliable, set sq_dist manually
            sq_dist[num_points_per_cloud < 7] = 1e-3
            h_k = 0.5 * sq_dist.max(dim=-1, keepdim=True)[0]
            # prevent outlier splats from being rendered too large or too small
            h_k = h_k.mean(dim=1, keepdim=True).clamp(5e-5, 1e-3)
            self._Vrk_h = h_k  # cache for subsequent calls with refresh=False

        # expand the per-cloud scaler to one value per packed point
        Vrk_h = gather_batch_to_packed(h_k,
                                       pointclouds.packed_to_cloud_idx())

        # Sk, a transformation from the 2D local surface frame to the 3D world frame.
        # By isotropy the two axes are equivalent, so we can simply
        # find two 3D vectors perpendicular to the point normals
        # (totalP, 2, 3)
        with torch.autograd.enable_grad():
            normals = pointclouds.normals_packed()

        u0 = F.normalize(torch.cross(normals,
                                     normals + torch.rand_like(normals)),
                         dim=-1)
        u1 = F.normalize(torch.cross(normals, u0), dim=-1)
        Sk = torch.stack([u0, u1], dim=1)
        Vrk = Vrk_h.view(-1, 1, 1) * Sk.transpose(1, 2) @ Sk
        return Vrk, Sk
Code example #14
File: frnn_whole.py  Project: lxxue/FRNN
 def output():
     dists, idxs, nn = knn_points(points1,
                                  points2,
                                  lengths1,
                                  lengths2,
                                  K,
                                  version=-1,
                                  return_nn=False,
                                  return_sorted=True)
     torch.cuda.synchronize()
Code example #15
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def uniform_loss(adv_pc,
                 percentages=[0.004, 0.006, 0.008, 0.010, 0.012],
                 radius=1.0,
                 k=2):
    if adv_pc.size(1) == 3:
        adv_pc = adv_pc.permute(0, 2, 1).contiguous()
    b, n, _ = adv_pc.size()
    npoint = int(n * 0.05)
    for p in percentages:
        p = p * 4
        nsample = int(n * p)
        r = math.sqrt(p * radius)
        disk_area = math.pi * (radius**2) * p / nsample
        expect_len = torch.sqrt(torch.Tensor([disk_area])).cuda()

        adv_pc_flipped = adv_pc.transpose(1, 2).contiguous()
        new_xyz = pointnet2_utils.gather_operation(
            adv_pc_flipped,
            pointnet2_utils.furthest_point_sample(adv_pc, npoint)).transpose(
                1, 2).contiguous()  # (batch_size, npoint, 3)

        idx = pointnet2_utils.ball_query(
            r, nsample, adv_pc, new_xyz)  #(batch_size, npoint, nsample)

        grouped_pcd = pointnet2_utils.grouping_operation(
            adv_pc_flipped,
            idx).permute(0, 2, 3,
                         1).contiguous()  # (batch_size, npoint, nsample, 3)
        grouped_pcd = torch.cat(torch.unbind(grouped_pcd, axis=1), axis=0)

        grouped_pcd = grouped_pcd.permute(0, 2, 1).contiguous()
        #dis = torch.sqrt(((grouped_pcd.unsqueeze(3) - grouped_pcd.unsqueeze(2))**2).sum(1)+1e-12) # (batch_size*npoint, nsample, nsample)
        #dists, _ = torch.topk(dis, k+1, dim=2, largest=False, sorted=True) # (batch_size*npoint, nsample, k+1)
        inter_KNN = knn_points(grouped_pcd.permute(0, 2, 1),
                               grouped_pcd.permute(0, 2, 1),
                               K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]

        uniform_dis = inter_KNN.dists[:, :, 1:].contiguous()
        uniform_dis = torch.sqrt(torch.abs(uniform_dis) + 1e-12)
        uniform_dis = uniform_dis.mean(axis=[-1])
        uniform_dis = (uniform_dis - expect_len)**2 / (expect_len + 1e-12)
        uniform_dis = torch.reshape(uniform_dis, [-1])

        mean = uniform_dis.mean()
        mean = mean * math.pow(p * 100, 2)

        try:
            loss = loss + mean
        except NameError:  # the first percentage level initializes the accumulator
            loss = mean
    return loss / len(percentages)
Code example #16
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def _get_kappa_ori(pc, normal, k=2):
    b, _, n = pc.size()
    #inter_dis = ((pc.unsqueeze(3) - pc.unsqueeze(2))**2).sum(1)
    #inter_idx = torch.topk(inter_dis, k+1, dim=2, largest=False, sorted=True)[1][:, :, 1:].contiguous()
    #nn_pts = torch.gather(pc, 2, inter_idx.view(b,1,n*k).expand(b,3,n*k)).view(b,3,n,k)
    inter_KNN = knn_points(pc.permute(0, 2, 1), pc.permute(0, 2, 1),
                           K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]
    nn_pts = knn_gather(pc.permute(0, 2, 1), inter_KNN.idx).permute(
        0, 3, 1, 2)[:, :, :, 1:].contiguous()  # [b, 3, n, k]
    vectors = nn_pts - pc.unsqueeze(3)
    vectors = _normalize(vectors)

    return torch.abs((vectors * normal.unsqueeze(3)).sum(1)).mean(2)  # [b, n]
Code example #17
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def curvature_loss(adv_pc, ori_pc, adv_kappa, ori_kappa, k=2):
    b, _, n = adv_pc.size()

    # intra_dis = ((input_curr_iter.unsqueeze(3) - pc_ori.unsqueeze(2))**2).sum(1)
    # intra_idx = torch.topk(intra_dis, 1, dim=2, largest=False, sorted=True)[1]
    # knn_theta_normal = torch.gather(theta_normal, 1, intra_idx.view(b,n).expand(b,n))
    # curv_loss = ((curv_loss - knn_theta_normal)**2).mean(-1)

    intra_KNN = knn_points(adv_pc.permute(0, 2, 1),
                           ori_pc.permute(0, 2, 1),
                           K=1)  #[dists:[b,n,1], idx:[b,n,1]]
    onenn_ori_kappa = torch.gather(
        ori_kappa, 1, intra_KNN.idx.squeeze(-1)).contiguous()  # [b, n]

    curv_loss = ((adv_kappa - onenn_ori_kappa)**2).mean(-1)

    return curv_loss
Code example #18
File: loss_utils.py  Project: Gorilla-Lab-SCUT/GeoA3
def kNN_smoothing_loss(adv_pc, k, threshold_coef=1.05):
    b, _, n = adv_pc.size()
    #dis = ((adv_pc.unsqueeze(3) - adv_pc.unsqueeze(2))**2).sum(1) #[b,n,n]
    #dis, idx = torch.topk(dis, k+1, dim=2, largest=False, sorted=True)#[b,n,k+1]
    inter_KNN = knn_points(adv_pc.permute(0, 2, 1),
                           adv_pc.permute(0, 2, 1),
                           K=k + 1)  #[dists:[b,n,k+1], idx:[b,n,k+1]]

    knn_dis = inter_KNN.dists[:, :, 1:].contiguous().mean(-1)  #[b,n]
    knn_dis_mean = knn_dis.mean(-1)  #[b]
    knn_dis_std = knn_dis.std(-1)  #[b]
    threshold = knn_dis_mean + threshold_coef * knn_dis_std  #[b]

    condition = torch.gt(knn_dis, threshold.unsqueeze(1)).float()  #[b,n]
    dis_mean = knn_dis * condition  #[b,n]

    return dis_mean.mean(1)  #[b]
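A hypothetical call; the loss is per cloud and only penalizes points whose mean k-NN distance exceeds the mean-plus-scaled-std threshold:

import torch
from pytorch3d.ops import knn_points

adv = torch.rand(4, 3, 1024)               # [b, 3, n]
print(kNN_smoothing_loss(adv, k=5).shape)  # torch.Size([4])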
Code example #19
File: layers.py  Project: FanLu97/PointINet
def knn_group_withI(points1, points2, intensity2, k):
    '''
    For each point in points1, query kNN points/intensities in points2/intensity2
    Input:
        points1: [B,3,N]
        points2: [B,3,N]
        intensity2: [B,1,N]
    Output:
        new_features: [B,4,N,k]
        nn: [B,3,N,k]
        grouped_features: [B,1,N,k]
    '''
    points1 = points1.permute(0,2,1).contiguous()
    points2 = points2.permute(0,2,1).contiguous()
    _, nn_idx, nn = knn_points(points1, points2, K=k, return_nn=True)
    points_resi = nn - points1.unsqueeze(2).repeat(1,1,k,1) # [B,M,k,3]
    grouped_dist = torch.norm(points_resi, dim=-1, keepdim=True)
    grouped_features = knn_gather(intensity2.permute(0,2,1), nn_idx) # [B,M,k,1]
    new_features = torch.cat([points_resi, grouped_dist], dim=-1)
    
    # [B,4,M,k], [B,3,M,k], [B,1,M,k]
    return new_features.permute(0,3,1,2).contiguous(), \
        nn.permute(0,3,1,2).contiguous(), \
        grouped_features.permute(0,3,1,2).contiguous()
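A sketch of calling it on random clouds (shapes as in the comment above):

import torch
from pytorch3d.ops import knn_points, knn_gather

p1 = torch.rand(2, 3, 1024)
p2 = torch.rand(2, 3, 1024)
intensity = torch.rand(2, 1, 1024)
feats, nn_pts, nn_int = knn_group_withI(p1, p2, intensity, k=16)
print(feats.shape, nn_pts.shape, nn_int.shape)
# torch.Size([2, 4, 1024, 16]) torch.Size([2, 3, 1024, 16]) torch.Size([2, 1, 1024, 16])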
Code example #20
def batch_normals(points, base=None, nn_size=20, NCHW=True, idx=None):
    """
    compute normals vectors for batched points [B, C, M]
    If base is given, compute the normals of points using the neighborhood in base
    The direction of normal could flip.

    Args:
        points:  (B,C,M)
        base:    (B,C,N)
        idx      (B,M,nn_size)
    Returns:
        normals: (B,C,M)
    """
    if base is None:
        base = points

    if NCHW:
        points = points.transpose(2, 1).contiguous()
        base = base.transpose(2, 1).contiguous()

    assert(nn_size < base.shape[1])
    batch_size, M, C = points.shape
    # B,M,k,C
    if idx is None:
        _, idx, grouped_points = ops.knn_points(points, base, K=nn_size, return_nn=True)
    else:
        grouped_points = torch.gather(base.unsqueeze(1).expand(-1,M,-1,-1), 2, idx.unsqueeze(-1).expand(-1,-1,-1,C))
    group_center = torch.mean(grouped_points, dim=2, keepdim=True)
    points = grouped_points - group_center
    allpoints = points.view(-1, nn_size, C).contiguous()
    # allpoints: (B*M, nn_size, C)
    U, S, V = batch_svd(allpoints)
    # V is (B*M)xCxC; the last right-singular vector spans the normal direction
    normals = V[:, :, -1]
    normals = normals.view(batch_size, M, C)
    if NCHW:
        normals = normals.transpose(1, 2)
    return normals, idx
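batch_svd is an external helper here; torch.svd already batches over the leading dimension, so it can stand in for a quick check (an assumption about the helper's contract):

import torch
from pytorch3d import ops

batch_svd = torch.svd  # assumed drop-in: batched SVD over the leading dimension

points = torch.rand(2, 3, 500)   # NCHW layout (B, C, M)
normals, idx = batch_normals(points, nn_size=20)
print(normals.shape, idx.shape)  # torch.Size([2, 3, 500]) torch.Size([2, 500, 20])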
Code example #21
 def forward(self, points_ref, points):
     """
     point1: (B,N,D) ref points (where connectivity is computed)
     point2: (B,N,D) pred points, uses connectivity of point1
     """
     # find neighborhood, (B,N,K,3), (B,N,K)
     _, knn_idx, group_points = ops.knn_points(points_ref,
                                               points_ref,
                                               K=self.nn_size + 1,
                                               return_nn=True)
     knn_idx = knn_idx[:, :, 1:]
     group_points = group_points[:, :, 1:, :]
     dist_ref = torch.norm(group_points - points_ref.unsqueeze(2),
                           dim=-1,
                           p=2)
     # dist_ref = torch.sqrt(dist_ref)
     # B,N,K,D
     group_points = torch.gather(
         points.unsqueeze(1).expand(-1, knn_idx.shape[1], -1, -1), 2,
         knn_idx.unsqueeze(-1).expand(-1, -1, -1, points.shape[-1]))
     dist = torch.norm(group_points - points.unsqueeze(2), dim=-1, p=2)
     # print(group_points, group_points2)
     return self.metric(dist_ref, dist)
Code example #22
def estimate_normal_via_ori_normal(pc_adv, pc_ori, normal_ori, k):
    # pc_adv, pc_ori, normal_ori : [b,3,n]
    b, _, n = pc_adv.size()
    intra_KNN = knn_points(pc_adv.permute(0, 2, 1),
                           pc_ori.permute(0, 2, 1),
                           K=k)  #[dists:[b,n,k], idx:[b,n,k]]
    inter_value = intra_KNN.dists[:, :, 0].contiguous()  # [b, n] squared distance to the 1-NN
    normal_pts = knn_gather(normal_ori.permute(
        0, 2, 1), intra_KNN.idx).permute(0, 3, 1,
                                         2).contiguous()  # [b, 3, n, k]

    normal_pts_avg = normal_pts.mean(dim=-1)
    normal_pts_avg = normal_pts_avg / (normal_pts_avg.norm(dim=1, keepdim=True) +
                                       1e-12)

    # If the points are not modified (distance = 0), use the normal directly from the original
    # one. Otherwise, use the mean of the normals of the k-nearest points.
    normal_ori_select = normal_pts[:, :, :, 0]
    condition = (inter_value < 1e-6).unsqueeze(1).expand_as(normal_ori_select)
    normals_estimated = torch.where(condition, normal_ori_select,
                                    normal_pts_avg)

    return normals_estimated
Code example #23
File: layers.py  Project: FanLu97/PointINet
    def knn_group(self, points1, points2, features2, k):
        '''
        For each point in points1, query kNN points/features in points2/features2
        Input:
            points1: [B,3,N]
            points2: [B,3,N]
            features2: [B,C,N]
        Output:
            new_features: [B,4,N,k]
            nn: [B,3,N,k]
            grouped_features: [B,C,N,k]
        '''
        points1 = points1.permute(0,2,1).contiguous()
        points2 = points2.permute(0,2,1).contiguous()
        _, nn_idx, nn = knn_points(points1, points2, K=k, return_nn=True)
        points_resi = nn - points1.unsqueeze(2).repeat(1,1,k,1)
        grouped_dist = torch.norm(points_resi, dim=-1, keepdim=True)
        grouped_features = knn_gather(features2.permute(0,2,1), nn_idx)
        new_features = torch.cat([points_resi, grouped_dist], dim=-1)

        return new_features.permute(0,3,1,2).contiguous(),\
            nn.permute(0,3,1,2).contiguous(),\
            grouped_features.permute(0,3,1,2).contiguous()
Code example #24
File: c3d_loss_knn.py  Project: minghanz/c3d
def knn_pcl(pcl_1, pcl_2, n_neighbors=1, return_nn=False):
    dists, idxs, pcl_knn_to_1 = knn_points(pcl_1.points_padded(),
                                           pcl_2.points_padded(),
                                           pcl_1.num_points_per_cloud(),
                                           pcl_2.num_points_per_cloud(),
                                           K=n_neighbors,
                                           return_nn=return_nn)
    return dists, idxs, pcl_knn_to_1
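A sketch with a heterogeneous Pointclouds batch; the returned squared distances are padded to the largest cloud in pcl_1:

import torch
from pytorch3d.structures import Pointclouds
from pytorch3d.ops import knn_points

pcl_a = Pointclouds([torch.rand(100, 3), torch.rand(80, 3)])
pcl_b = Pointclouds([torch.rand(120, 3), torch.rand(90, 3)])
dists, idxs, _ = knn_pcl(pcl_a, pcl_b)
print(dists.shape)  # torch.Size([2, 100, 1])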
Code example #25
def init_boundary_volume(
    batch_size: int,
    volume_size: Tuple[int, int, int],
    border_offset: int = 2,
    shape: str = "cube",
    volume_translation: torch.Tensor = ZERO_TRANSLATION,
):
    """
    Generate a volume with sides colored with distinct colors.
    """

    device = torch.device("cuda")

    # first center the volume for the purpose of generating the canonical shape
    volume_translation_tmp = (0.0, 0.0, 0.0)

    # set the voxel size to 1 / (volume_size-1)
    volume_voxel_size = 1 / (volume_size[0] - 1.0)

    # colors of the sides of the cube
    clr_sides = torch.tensor(
        [
            [1.0, 1.0, 1.0],
            [1.0, 0.0, 0.0],
            [1.0, 0.0, 1.0],
            [1.0, 1.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 1.0, 1.0],
        ],
        dtype=torch.float32,
        device=device,
    )

    # get the coord grid of the volume
    coord_grid = Volumes(
        densities=torch.zeros(1, 1, *volume_size, device=device),
        voxel_size=volume_voxel_size,
        volume_translation=volume_translation_tmp,
    ).get_coord_grid()[0]

    # extract the boundary points and their colors of the cube
    if shape == "cube":
        boundary_points, boundary_colors = [], []
        for side, clr_side in enumerate(clr_sides):
            first = side % 2
            dim = side // 2
            slices = [slice(border_offset, -border_offset, 1)] * 3
            slices[dim] = int(border_offset * (2 * first - 1))
            slices.append(slice(0, 3, 1))
            boundary_points_ = coord_grid[slices].reshape(-1, 3)
            boundary_points.append(boundary_points_)
            boundary_colors.append(clr_side[None].expand_as(boundary_points_))
        # set the internal part of the volume to be completely opaque
        volume_densities = torch.zeros(*volume_size, device=device)
        volume_densities[[slice(border_offset, -border_offset, 1)] * 3] = 1.0
        boundary_points, boundary_colors = [
            torch.cat(p, dim=0) for p in [boundary_points, boundary_colors]
        ]
        # color the volume voxels with the nearest boundary points' color
        _, idx, _ = knn_points(coord_grid.view(1, -1, 3),
                               boundary_points.view(1, -1, 3))
        volume_colors = (boundary_colors[idx.view(-1)].view(*volume_size,
                                                            3).permute(
                                                                3, 0, 1, 2))

    elif shape == "sphere":
        # set all voxels within a certain distance from the origin to be opaque
        volume_densities = (coord_grid.norm(dim=-1) <=
                            0.5 * volume_voxel_size *
                            (volume_size[0] - border_offset)).float()
        # color each voxel with the standard spherical color
        volume_colors = (
            (torch.nn.functional.normalize(coord_grid, dim=-1) + 1.0) *
            0.5).permute(3, 0, 1, 2)

    else:
        raise ValueError(shape)

    volume_voxel_size = torch.ones(
        (batch_size, 1), device=device) * volume_voxel_size
    volume_translation = volume_translation.expand(batch_size, 3)
    volumes = Volumes(
        densities=volume_densities[None, None].expand(batch_size, 1,
                                                      *volume_size),
        features=volume_colors[None].expand(batch_size, 3, *volume_size),
        voxel_size=volume_voxel_size,
        volume_translation=volume_translation,
    )

    return volumes, volume_voxel_size, volume_translation
Code example #26
def iterative_closest_point(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    init_transform: Optional[SimilarityTransform] = None,
    max_iterations: int = 100,
    relative_rmse_thr: float = 1e-6,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    verbose: bool = False,
) -> ICPSolution:
    """
    Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
    a similarity transformation (rotation `R`, translation `T`, and
    optionally scale `s`) between two given differently-sized sets of
    `d`-dimensional points `X` and `Y`, such that:

    `s[i] X[i] R[i] + T[i] = Y[NN[i]]`,

    for all batch indices `i` in the least squares sense. Here, `NN[i]` stands
    for the indices of the nearest neighbors from `Y` to each point in `X`.
    Note, however, that the solution is only a local optimum.

    Args:
        **X**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
        **init_transform**: A named-tuple `SimilarityTransform` of tensors
            `R`, `T`, `s`, where `R` is a batch of orthonormal matrices of
            shape `(minibatch, d, d)`, `T` is a batch of translations
            of shape `(minibatch, d)` and `s` is a batch of scaling factors
            of shape `(minibatch,)`.
        **max_iterations**: The maximum number of ICP iterations.
        **relative_rmse_thr**: A threshold on the relative root mean squared error
            used to terminate the algorithm.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes the identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **verbose**: If `True`, prints status messages during each ICP iteration.

    Returns:
        A named tuple `ICPSolution` with the following fields:
        **converged**: A boolean flag denoting whether the algorithm converged
            successfully (=`True`) or not (=`False`).
        **rmse**: Attained root mean squared error after termination of ICP.
        **Xt**: The point cloud `X` transformed with the final transformation
            (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
            instance of `Pointclouds`, otherwise returns `torch.Tensor`.
        **RTs**: A named tuple `SimilarityTransform` containing
        a batch of similarity transforms with fields:
            **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
            **T**: Batch of translations of shape `(minibatch, d)`.
            **s**: batch of scaling factors of shape `(minibatch, )`.
        **t_history**: A list of named tuples `SimilarityTransform` containing
            the transformation parameters after each ICP iteration.

    References:
        [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
        [2] https://en.wikipedia.org/wiki/Iterative_closest_point
    """

    # make sure we convert input Pointclouds structures to
    # padded tensors of shape (N, P, 3)
    Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    b, size_X, dim = Xt.shape

    if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
        raise ValueError(
            "Point sets X and Y have to have the same "
            + "number of batches and data dimensions."
        )

    if ((num_points_Y < Yt.shape[1]).any() or (num_points_X < Xt.shape[1]).any()) and (
        num_points_Y != num_points_X
    ).any():
        # we have a heterogeneous input (e.g. because X/Y is
        # an instance of Pointclouds)
        mask_X = (
            torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
            < num_points_X[:, None]
        ).type_as(Xt)
    else:
        mask_X = Xt.new_ones(b, size_X)

    # clone the initial point cloud
    Xt_init = Xt.clone()

    if init_transform is not None:
        # parse the initial transform from the input and apply to Xt
        try:
            R, T, s = init_transform
            assert (
                R.shape == torch.Size((b, dim, dim))
                and T.shape == torch.Size((b, dim))
                and s.shape == torch.Size((b,))
            )
        except Exception:
            raise ValueError(
                "The initial transformation init_transform has to be "
                "a named tuple SimilarityTransform with elements (R, T, s). "
                "R are dim x dim orthonormal matrices of shape "
                "(minibatch, dim, dim), T is a batch of dim-dimensional "
                "translations of shape (minibatch, dim) and s is a batch "
                "of scalars of shape (minibatch,)."
            )
        # apply the init transform to the input point cloud
        Xt = _apply_similarity_transform(Xt, R, T, s)
    else:
        # initialize the transformation with identity
        R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
        T = Xt.new_zeros((b, dim))
        s = Xt.new_ones(b)

    prev_rmse = None
    rmse = None
    iteration = -1
    converged = False

    # initialize the transformation history
    t_history = []

    # the main loop over ICP iterations
    for iteration in range(max_iterations):
        Xt_nn_points = knn_points(
            Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
        ).knn[:, :, 0, :]

        # get the alignment of the nearest neighbors from Yt with Xt_init
        R, T, s = corresponding_points_alignment(
            Xt_init,
            Xt_nn_points,
            weights=mask_X,
            estimate_scale=estimate_scale,
            allow_reflection=allow_reflection,
        )

        # apply the estimated similarity transform to Xt_init
        Xt = _apply_similarity_transform(Xt_init, R, T, s)

        # add the current transformation to the history
        t_history.append(SimilarityTransform(R, T, s))

        # compute the root mean squared error
        Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
        rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]

        # compute the relative rmse
        if prev_rmse is None:
            relative_rmse = rmse.new_ones(b)
        else:
            relative_rmse = (prev_rmse - rmse) / prev_rmse

        if verbose:
            rmse_msg = (
                f"ICP iteration {iteration}: mean/max rmse = "
                + f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
                + f"; mean relative rmse = {relative_rmse.mean():1.2e}"
            )
            print(rmse_msg)

        # check for convergence
        if (relative_rmse <= relative_rmse_thr).all():
            converged = True
            break

        # update the previous rmse
        prev_rmse = rmse

    if verbose:
        if converged:
            print(f"ICP has converged in {iteration + 1} iterations.")
        else:
            print(f"ICP has not converged in {max_iterations} iterations.")

    if oputil.is_pointclouds(X):
        Xt = X.update_padded(Xt)  # type: ignore

    return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
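A minimal self-check: transform a cloud by a known rotation and translation and let ICP recover it (a sketch, not from the library's tests):

import math
import torch
from pytorch3d.ops import iterative_closest_point

c, s = math.cos(0.3), math.sin(0.3)
R = torch.tensor([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])[None]  # (1, 3, 3)
X = torch.rand(1, 100, 3)
Y = X @ R + torch.tensor([0.5, -0.2, 0.1])  # Y = X R + T, the convention above
sol = iterative_closest_point(X, Y)
print(sol.converged, float(sol.rmse))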
Code example #27
    def mine_triplets(self, batch, true_edges, spatial, r_max, k_max):

        # -------- TRUTH
        torch_e = torch.sparse.FloatTensor(
            true_edges,
            torch.ones(true_edges.shape[1]).to(device),
            size=(len(spatial), len(spatial)),
        )
        sparse_sum = torch.sparse.sum(torch_e, dim=0)
        num_true_torch = torch.zeros(len(spatial)).to(device).int()
        num_true_torch[sparse_sum.indices()] = sparse_sum.values().int()
        sorted_true_indices = torch.argsort(true_edges[0])
        sorted_true_edges = true_edges[:, sorted_true_indices]

        # --------- HNM
        knn_object = ops.knn_points(spatial.unsqueeze(0),
                                    spatial.unsqueeze(0),
                                    K=k_max,
                                    return_sorted=False)
        I, D = knn_object.idx[0], knn_object.dists[0]

        # ---------- Shuffle
        shuffled_index = torch.randperm(I.shape[1])
        shuffled_I = I[:, shuffled_index]
        shuffled_D = D[:, shuffled_index]
        ind = torch.arange(shuffled_I.shape[0],
                           device=device).unsqueeze(1).expand_as(shuffled_I)

        # ---------- Constraints
        # mask out neighbours outside the (r_max / 2, r_max] distance band
        shuffled_I[(shuffled_D > r_max) | (shuffled_D < (r_max / 2))] = -1
        shuffled_I[batch.pid[ind] == batch.pid[shuffled_I]] = -1
        shuffled_I[ind == shuffled_I] = -1

        # ----------- Reshape with -1's
        squished_I = push_all_negs_back(shuffled_I.cpu().numpy())
        squished_I = torch.from_numpy(squished_I).to(device)

        # ---------- Handle # pos > # neg
        pos_available = num_true_torch > 0
        squished_I = torch.cat(
            [
                squished_I,
                -1 * torch.ones(
                    squished_I.shape[0],
                    max(0,
                        num_true_torch.max() - k_max),
                    dtype=int,
                    device=device,
                ),
            ],
            axis=-1,
        )

        # ----------- Build Triplets
        selected_negatives = squished_I[pos_available][num_true_torch[
            pos_available,
            None] > torch.arange(squished_I.shape[1], device=device)]
        triplets = torch.cat(
            [sorted_true_edges,
             selected_negatives.unsqueeze(0)], axis=0)
        triplets = triplets[:, triplets[2] != -1]

        return triplets
Code example #28
File: rasterizer.py  Project: yifita/DSS
    def _compute_isotropic_Vrk(self, pointclouds, refresh=True, **kwargs):
        """
        determine the variance in the local surface frame h * Sk.T @ Sk,
        where Sk is 2x3 local surface coordinate to world coordinate.
        determine the h_k in V_k^r = h_k*Id using nearest neighbor
        heuristically h_k = mean(dist between points in a small neighbor)
        The larger h_k is, the larger the splat is
        NOTE: h_k in inverse to the definition in the paper, the larger h_k, the
            larger the splats
        Args:
            pointclouds: pointcloud in object coordinate
        Returns:
            h_k: [N,3,3] tensor for each point
            S_k: [N,2,3] local frame
        """
        if not refresh and self._Vrk_h is not None and \
                pointclouds.num_points_per_cloud().sum() == self._Vrk_h.shape[0]:
            pass
        else:
            with torch.autograd.enable_grad():
                pts_world = pointclouds.points_padded()

            num_points_per_cloud = pointclouds.num_points_per_cloud()
            if self.frnn_radius <= 0:
                # logger_py.info("vrk knn points")
                sq_dist, _, _ = ops3d.knn_points(pts_world,
                                                 pts_world,
                                                 num_points_per_cloud,
                                                 num_points_per_cloud,
                                                 K=7)
            else:
                sq_dist, _, _, _ = frnn.frnn_grid_points(pts_world,
                                                         pts_world,
                                                         num_points_per_cloud,
                                                         num_points_per_cloud,
                                                         K=7,
                                                         r=self.frnn_radius)

            sq_dist = sq_dist[:, :, 1:]
            # knn search is unreliable, set sq_dist manually
            sq_dist[num_points_per_cloud < 7] = 1e-3
            # (totalP, knnK)
            sq_dist = ops3d.padded_to_packed(
                sq_dist, pointclouds.cloud_to_packed_first_idx(),
                num_points_per_cloud.sum().item())
            # [totalP, ]
            h_k = 0.5 * sq_dist.max(dim=-1, keepdim=True)[0]

            # prevent outlier splats from being rendered too large or too small
            self._Vrk_h = h_k.clamp(5e-5, 0.01)

        # Sk, a transformation from the 2D local surface frame to the 3D world frame.
        # By isotropy the two axes are equivalent, so we can simply
        # find two 3D vectors perpendicular to the point normals
        # (totalP, 2, 3)
        with torch.autograd.enable_grad():
            normals = pointclouds.normals_packed()

        u0 = F.normalize(torch.cross(normals,
                                     normals + torch.rand_like(normals)),
                         dim=-1)
        u1 = F.normalize(torch.cross(normals, u0), dim=-1)
        Sk = torch.stack([u0, u1], dim=1)
        Vrk = self._Vrk_h.view(-1, 1, 1) * Sk.transpose(1, 2) @ Sk
        return Vrk, Sk
Code example #29
File: fitting.py  Project: pgrady3/prox
    def forward(self,
                body_model_output,
                camera,
                gt_joints,
                joints_conf,
                body_model_faces,
                joint_weights,
                use_vposer=False,
                pose_embedding=None,
                scan_tensor=None,
                visualize=False,
                scene_v=None,
                scene_vn=None,
                scene_f=None,
                ftov=None,
                **kwargs):
        batch_size = gt_joints.shape[0]

        forehead_vert_id = 336  # Patrick: replace head joint with a vertex on the head
        model_joints = body_model_output.joints
        model_joints[:, 0, :] = body_model_output.vertices[:,
                                                           forehead_vert_id, :]

        projected_joints = camera(model_joints)
        # Calculate the weights for each joints
        weights = (joint_weights * joints_conf if self.use_joints_conf else
                   joint_weights).unsqueeze(dim=-1)

        # Calculate the distance of the projected joints from
        # the ground truth 2D detections
        joint_err = gt_joints - projected_joints
        joint_err_sq = joint_err.pow(2)
        # joint_diff = self.robustifier(gt_joints - projected_joints)
        joint_loss = (torch.sum(weights**2 * joint_err_sq, dim=[1, 2]) *
                      self.data_weight**2)

        # Calculate the loss from the Pose prior
        if use_vposer:
            pprior_loss = (pose_embedding.pow(2).sum() *
                           self.body_pose_weight**2)
        else:
            pprior_loss = self.body_pose_prior(
                body_model_output.body_pose,
                body_model_output.betas) * self.body_pose_weight**2

        shape_loss = self.shape_prior(
            body_model_output.betas) * self.shape_weight**2

        # Patrick - weight and height loss
        tmp_betas = body_model_output.betas
        tmp_gender = self.gender_tensor
        # print('Gender tensor', tmp_gender)
        batch_weight_est, batch_height_est = self.betanet(
            tmp_gender, tmp_betas)
        batch_height_est = 100 * batch_height_est
        # print('Height {} est {}'.format(self.height, batch_height_est.detach().item()))
        # print('Weight {} est {}'.format(self.weight, batch_weight_est.detach().item()))
        # print('Cur gender flag', tmp_gender)
        global_vars.cur_weight = batch_weight_est.detach().cpu().numpy()
        global_vars.cur_height = batch_height_est.detach().cpu().numpy()
        d_weight = self.weight - batch_weight_est.squeeze()
        d_height = self.height - batch_height_est.squeeze()
        physical_loss = d_weight.pow(2) * self.weight_w + d_height.pow(
            2) * self.height_w

        # Calculate the prior over the joint rotations. This a heuristic used
        # to prevent extreme rotation of the elbows and knees
        body_pose = body_model_output.full_pose[:, 3:66]

        # Patrick: turn off this loss
        angle_prior_loss = 0 * torch.sum(
            self.angle_prior(body_pose)) * self.bending_prior_weight**2

        # Apply the prior on the pose space of the hand
        left_hand_prior_loss, right_hand_prior_loss = 0.0, 0.0
        if self.use_hands and self.left_hand_prior is not None:
            left_hand_prior_loss = torch.sum(
                self.left_hand_prior(
                    body_model_output.left_hand_pose)) * \
                self.hand_prior_weight ** 2

        if self.use_hands and self.right_hand_prior is not None:
            right_hand_prior_loss = torch.sum(
                self.right_hand_prior(
                    body_model_output.right_hand_pose)) * \
                self.hand_prior_weight ** 2

        expression_loss = 0.0
        jaw_prior_loss = 0.0
        if self.use_face:
            expression_loss = torch.sum(self.expr_prior(
                body_model_output.expression)) * \
                self.expr_prior_weight ** 2

            if hasattr(self, 'jaw_prior'):
                jaw_prior_loss = torch.sum(
                    self.jaw_prior(
                        body_model_output.jaw_pose.mul(self.jaw_prior_weight)))

        pen_loss = 0.0
        # Calculate the loss due to interpenetration
        if (self.interpenetration and self.coll_loss_weight.item() > 0):
            batch_size = projected_joints.shape[0]
            triangles = torch.index_select(body_model_output.vertices, 1,
                                           body_model_faces.view(-1)).view(
                                               batch_size, -1, 3, 3)

            with torch.no_grad():
                collision_idxs = self.search_tree(triangles)

            # Remove unwanted collisions
            if self.tri_filtering_module is not None:
                collision_idxs = self.tri_filtering_module(collision_idxs)

            if collision_idxs.ge(0).sum().item() > 0:
                pen_loss = self.coll_loss_weight * self.pen_distance(
                    triangles, collision_idxs)

        s2m_dist = 0.0
        m2s_dist = 0.0
        # calculate the scan2mesh and mesh2scan loss from the sparse point cloud
        if (self.s2m or self.m2s) and (self.s2m_weight > 0 or self.m2s_weight >
                                       0) and scan_tensor is not None:
            # vertices_np = body_model_output.vertices.detach().cpu().numpy().squeeze()
            # body_faces_np = body_model_faces.detach().cpu().numpy().reshape(-1, 3)
            # m = Mesh(v=vertices_np, f=body_faces_np)
            #
            # (vis, n_dot) = visibility_compute(v=m.v, f=m.f, cams=np.array([[0.0, 0.0, 0.0]]))
            # vis = vis.squeeze()

            # TODO ignore body mask?
            # TODO ignore visibility map? Would help with points getting stuck inside body, only care about normal
            # Maybe need to do this in packed?

            in_mesh_verts = body_model_output.vertices
            in_mesh_faces = body_model_faces.unsqueeze(0).repeat(
                batch_size, 1, 1)
            body_mesh = pytorch3d.structures.Meshes(verts=in_mesh_verts,
                                                    faces=in_mesh_faces)
            mesh_normals = body_mesh.verts_normals_padded()
            mesh_verts = body_mesh.verts_padded()
            camera_pos = torch.tensor([0.0, 0.0, 0.0],
                                      device=mesh_verts.device)
            vec_mesh_to_cam = camera_pos - mesh_verts
            towards_camera = torch.sum(mesh_normals * vec_mesh_to_cam,
                                       dim=2) > 0
            num_mesh_verts_towards_camera = torch.sum(towards_camera, dim=1)

            mesh_verts_towards_camera = torch.zeros(
                [batch_size,
                 num_mesh_verts_towards_camera.max(), 3],
                device=mesh_verts.device)
            for i in range(
                    batch_size):  # There's probably a cleaner way to do this
                mesh_verts_towards_camera[
                    i, 0:num_mesh_verts_towards_camera[i], :] = mesh_verts[
                        i, towards_camera[i, :], :]

            num_scan_points = scan_tensor.num_points_per_cloud(
            )  # number of points in each scan cloud
            if self.s2m and self.s2m_weight > 0:
                # Note, returns squared distance
                s2m_dist = knn_points(scan_tensor.points_padded(),
                                      mesh_verts_towards_camera,
                                      lengths1=num_scan_points,
                                      lengths2=num_mesh_verts_towards_camera,
                                      K=1)

                # s2m_dist, _, _, _ = distChamfer(scan_tensor, body_model_output.vertices[:, np.where(vis > 0)[0], :])
                s2m_dist = self.s2m_robustifier(s2m_dist.dists.sqrt())
                s2m_dist = self.s2m_weight * s2m_dist.sum(dim=[1, 2])
            if self.m2s and self.m2s_weight > 0:
                # _, m2s_dist, _, _ = distChamfer(scan_tensor, body_model_output.vertices[:, np.where(np.logical_and(vis > 0, self.body_mask))[0], :])
                m2s_dist = knn_points(mesh_verts_towards_camera,
                                      scan_tensor.points_padded(),
                                      lengths2=num_scan_points,
                                      lengths1=num_mesh_verts_towards_camera,
                                      K=1)

                m2s_dist = self.m2s_robustifier(m2s_dist.dists.sqrt())
                m2s_dist = self.m2s_weight * m2s_dist.sum(dim=[1, 2])

        # Transform vertices to world coordinates
        if self.R is not None and self.t is not None:
            vertices = body_model_output.vertices.view(-1, 3)
            nv = vertices.shape[0]
            vertices = self.R.mm(vertices.t()).t() + self.t.repeat([nv, 1])
            vertices = vertices.reshape(batch_size, -1, 3)

        # Compute scene penetration using signed distance field (SDF)
        # sdf_penetration_loss = 0.0
        # if self.sdf_penetration and self.sdf_penetration_weight > 0:
        #     grid_dim = self.sdf.shape[0]
        #     sdf_ids = torch.round((vertices.squeeze() - self.grid_min) / self.voxel_size).to(dtype=torch.long)  # Convert SMPL vertex to closest voxel ID
        #     sdf_ids.clamp_(min=0, max=grid_dim-1)   # Clamp to limits of grid
        #
        #     norm_vertices = (vertices - self.grid_min) / (self.grid_max - self.grid_min) * 2 - 1    # Put SMPL verts into voxel space
        #     body_sdf = F.grid_sample(self.sdf.view(1, 1, grid_dim, grid_dim, grid_dim),
        #                              norm_vertices[:, :, [2, 1, 0]].view(1, nv, 1, 1, 3),
        #                              padding_mode='border')     # Calculate SDF for each SMPL vertex
        #     sdf_normals = self.sdf_normals[sdf_ids[:,0], sdf_ids[:,1], sdf_ids[:,2]]    # Find the SDF normal for each SMPL vertex
        #     # if there are no penetrating vertices then set sdf_penetration_loss = 0
        #     if body_sdf.lt(0).sum().item() < 1:
        #         sdf_penetration_loss = torch.tensor(0.0, dtype=joint_loss.dtype, device=joint_loss.device)
        #     else:
        #       if sdf_normals is None:
        #         sdf_penetration_loss = self.sdf_penetration_weight * (body_sdf[body_sdf < 0].unsqueeze(dim=-1).abs()).pow(2).sum(dim=-1).sqrt().sum()
        #       else:
        #         sdf_penetration_loss = self.sdf_penetration_weight * (body_sdf[body_sdf < 0].unsqueeze(dim=-1).abs() * sdf_normals[body_sdf.view(-1) < 0, :]).pow(2).sum(dim=-1).sqrt().sum()

        sdf_penetration_loss = 0.0
        if self.sdf_penetration and self.sdf_penetration_weight > 0:
            # Bed is at +z 2150, pointing down
            bed_height = 2.150

            body_sdf = bed_height - vertices[:, :,
                                             2]  # Less than zero inside bed
            sdf_normals = torch.zeros_like(vertices)
            sdf_normals[:, :, 2] = -1

            # if there are no penetrating vertices then set sdf_penetration_loss = 0
            if body_sdf.lt(0).sum().item() < 1:
                sdf_penetration_loss = torch.tensor(0.0,
                                                    dtype=joint_loss.dtype,
                                                    device=joint_loss.device)
            else:
                sel_sdf = torch.max(body_sdf, torch.zeros_like(body_sdf))
                sdf_dot = sel_sdf.unsqueeze(2) * sdf_normals
                sdf_penetration_loss = self.sdf_penetration_weight * sdf_dot.pow(
                    2).sum(dim=2).sqrt().sum(dim=1)

        # Compute the contact loss
        contact_loss = 0.0
        if self.contact and self.contact_loss_weight > 0:
            # select contact vertices
            contact_body_vertices = vertices[:, self.contact_verts_ids, :]
            contact_dist, _, idx1, _ = distChamfer(
                contact_body_vertices.contiguous(), scene_v)

            body_triangles = torch.index_select(vertices, 1,
                                                body_model_faces).view(
                                                    1, -1, 3, 3)
            # Calculate the edges of the triangles
            # Size: BxFx3
            edge0 = body_triangles[:, :, 1] - body_triangles[:, :, 0]
            edge1 = body_triangles[:, :, 2] - body_triangles[:, :, 0]
            # Compute the cross product of the edges to find the normal vector of
            # the triangle
            body_normals = torch.cross(edge0, edge1, dim=2)
            # Normalize the result to get a unit vector
            body_normals = body_normals / \
                torch.norm(body_normals, 2, dim=2, keepdim=True)
            # compute the vertex normals
            body_v_normals = torch.mm(ftov, body_normals.squeeze())
            body_v_normals = body_v_normals / \
                torch.norm(body_v_normals, 2, dim=1, keepdim=True)

            # vertex normals of contact vertices
            contact_body_verts_normals = body_v_normals[
                self.contact_verts_ids, :]
            # scene normals of the closest points on the scene surface to the contact vertices
            contact_scene_normals = scene_vn[:,
                                             idx1.squeeze().to(dtype=torch.long
                                                               ), :].squeeze()

            # compute the angle between contact_verts normals and scene normals
            angles = torch.asin(
                torch.norm(torch.cross(contact_body_verts_normals,
                                       contact_scene_normals),
                           2,
                           dim=1,
                           keepdim=True)) * 180 / np.pi

            # consider only the vertices which their normals match
            valid_contact_mask = (angles.le(self.contact_angle) +
                                  angles.ge(180 - self.contact_angle)).ge(1)
            valid_contact_ids = valid_contact_mask.squeeze().nonzero().squeeze(
            )

            contact_dist = self.contact_robustifier(
                contact_dist[:, valid_contact_ids].sqrt())
            contact_loss = self.contact_loss_weight * contact_dist.mean()

        total_loss = (joint_loss + pprior_loss + shape_loss +
                      angle_prior_loss + pen_loss + jaw_prior_loss +
                      expression_loss + left_hand_prior_loss +
                      right_hand_prior_loss + m2s_dist + s2m_dist +
                      sdf_penetration_loss + contact_loss + physical_loss)

        loss_dict = {
            'total': total_loss,
            'joint': joint_loss,
            's2m': s2m_dist,
            'm2s': m2s_dist,
            'pprior': pprior_loss,
            'shape': shape_loss,
            'angle_prior': angle_prior_loss,
            'pen': pen_loss,
            'sdf_penetration': sdf_penetration_loss,
            'contact': contact_loss,
            'physical': physical_loss
        }

        global_vars.cur_loss_dict = dict()
        for key, item in loss_dict.items():
            if torch.is_tensor(item):
                global_vars.cur_loss_dict[key] = item.detach().cpu().numpy()
            else:
                global_vars.cur_loss_dict[key] = item

        if visualize:
            np.set_printoptions(precision=1)
            for key, value in global_vars.cur_loss_dict.items():
                if isinstance(value, np.ndarray):
                    if value.sum() == 0:
                        continue
                else:
                    if value == 0:
                        continue
                print('{}:{}'.format(key, global_vars.cur_loss_dict[key]),
                      end=' ')
            print('max_joint:{}'.format(global_vars.cur_max_joint))

        return total_loss.sum()
Code example #30
def estimate_pointcloud_local_coord_frames(
    pointclouds: Union[torch.Tensor, Pointclouds],
    neighborhood_size: int = 50,
    disambiguate_directions: bool = True,
    return_knn_result: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, Optional['KNN']]:
    """
    Faster version of pytorch3d estimate_pointcloud_local_coord_frames

    Estimates the principal directions of curvature (which includes normals)
    of a batch of `pointclouds`.
    Returns:
        curvatures (N,P,3) ascending order
        local_frames (N,P,3,3) corresponding eigenvectors
    """
    points_padded, num_points = convert_pointclouds_to_tensor(pointclouds)

    ba, N, dim = points_padded.shape
    if dim != 3:
        raise ValueError(
            "The pointclouds argument has to be of shape (minibatch, N, 3)")

    if (num_points <= neighborhood_size).any():
        raise ValueError("The neighborhood_size argument has to be" +
                         " >= size of each of the point clouds.")
    # undo global mean for stability
    # TODO: replace with tutil.wmean once landed
    pcl_mean = points_padded.sum(1) / num_points[:, None]
    points_centered = points_padded - pcl_mean[:, None, :]

    # get K nearest neighbor idx for each point in the point cloud
    knn_result = knn_points(
        points_padded,
        points_padded,
        lengths1=num_points,
        lengths2=num_points,
        K=neighborhood_size,
        return_nn=True,
    )
    k_nearest_neighbors = knn_result.knn
    # obtain the mean of the neighborhood
    pt_mean = k_nearest_neighbors.mean(2, keepdim=True)
    # compute the diff of the neighborhood and the mean of the neighborhood
    # N,P,K,3
    central_diff = k_nearest_neighbors - pt_mean
    per_pts_diff = central_diff.view(-1, neighborhood_size, 3)
    # S (NP,3) and local_coord_frames (NP,3,3)
    _, S, local_coord_frames = batch_svd(per_pts_diff)
    curvature = S * S / neighborhood_size
    local_coord_frames = local_coord_frames.view(ba, N, dim, dim)
    curvature = curvature.view(ba, N, dim)

    # flip to ascending order
    curvature = curvature.flip(-1)
    local_coord_frames = local_coord_frames.flip(-1)

    # disambiguate the directions of individual principal vectors
    if disambiguate_directions:
        # disambiguate normal
        n = _disambiguate_vector_directions(points_centered,
                                            k_nearest_neighbors,
                                            local_coord_frames[:, :, :, 0])
        # disambiguate the main curvature
        z = _disambiguate_vector_directions(points_centered,
                                            k_nearest_neighbors,
                                            local_coord_frames[:, :, :, 2])
        # the secondary curvature is just a cross between n and z
        y = torch.cross(n, z, dim=2)
        # cat to form the set of principal directions
        local_coord_frames = torch.stack((n, y, z), dim=3)

    if return_knn_result:
        return curvature, local_coord_frames, knn_result
    return curvature, local_coord_frames
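A sketch of reading normals out of the returned frames, assuming the module's helpers (batch_svd, convert_pointclouds_to_tensor, _disambiguate_vector_directions) are in scope; after disambiguation the first column of each frame is the normal direction:

import torch

pts = torch.rand(2, 2000, 3)
curvature, frames = estimate_pointcloud_local_coord_frames(pts)
normals = frames[:, :, :, 0]          # (2, 2000, 3): smallest-curvature direction
print(curvature.shape, frames.shape)  # torch.Size([2, 2000, 3]) torch.Size([2, 2000, 3, 3])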