Beispiel #1
0
def sample_and_group(npoint, nsample, xyz, points):
    """Sample ``npoint`` centroids and group ``nsample`` neighbours around each.

    Centroids are selected with ``FurthestPointSampler``. Neighbours are found
    with ``knn_point`` on the coordinates in ``xyz``. For each centroid, the
    neighbours' features are expressed relative to the centroid's feature.
    A copy of the centroid's feature is then concatenated onto them.

    Args:
        npoint: number of centroids to sample (S).
        nsample: number of neighbours gathered per centroid (K).
        xyz: point coordinates, shape ``[B, N, C]``. The unpacking below
            requires a rank-3 shape.
        points: per-point features. Assumed to be ``[B, N, D]``; TODO confirm.

    Returns:
        ``(new_xyz, new_points)``:

        - ``new_xyz`` holds the sampled centroid coordinates.
        - ``new_points`` concatenates, on the last axis,
          ``neighbour_feature - centre_feature`` and the centre feature
          repeated over the K neighbours. Under the ``[B, N, D]`` assumption
          its shape is ``[B, S, K, 2*D]``.
    """
    # Only the batch size is needed. The 3-way unpacking also rejects
    # inputs that are not rank 3, as the original did.
    B, _, _ = xyz.shape
    S = npoint

    sampler = FurthestPointSampler(npoint)
    _, fps_idx = sampler(xyz)  # [B, npoint]
    new_xyz = index_points(xyz, fps_idx)
    # Centre features, reshaped once so they broadcast over the neighbour axis.
    centre_points = index_points(points, fps_idx).view(B, S, 1, -1)

    idx = knn_point(nsample, xyz, new_xyz)  # presumably [B, npoint, nsample]
    # Only features are grouped. Coordinate offsets are not part of the
    # output, so they are no longer computed.
    grouped_points = index_points(points, idx)
    grouped_points_norm = grouped_points - centre_points
    new_points = concat(
        [grouped_points_norm, centre_points.repeat(1, 1, nsample, 1)],
        dim=-1,
    )
    return new_xyz, new_points
Beispiel #2
0
 def __init__(self, C_in: int, C_out: int, dims: int, K: int, D: int,
              P: int) -> None:
     """Wrap a ``PointCNN`` layer and optionally attach a point sampler.

     Every argument is forwarded unchanged to ``PointCNN``; see that class
     for the meaning of each one. ``P`` is also stored as ``self.P``. When
     it is positive, a ``FurthestPointSampler(P)`` is stored as
     ``self.sampler``. Otherwise no ``sampler`` attribute is created.
     """
     super(RandPointCNN, self).__init__()
     layer = PointCNN(C_in, C_out, dims, K, D, P)
     self.pointcnn = layer
     self.P = P
     use_sampler = P > 0
     if use_sampler:
         self.sampler = FurthestPointSampler(P)
Beispiel #3
0
    def __init__(self, n_points: int, radius: List[float], n_samples: List[int],
                 mlps: List[List[int]], bn=True, use_xyz=True):
        """Build the sampler, one grouper per (radius, n_samples) pair, and the MLPs.

        Args:
            n_points: number of points requested from ``FurthestPointSampler``.
            radius: radii, one per ``BallQueryGrouper``.
            n_samples: sample counts, paired element-wise with ``radius``.
                ``zip`` stops at the shorter of the two lists.
            mlps: MLP specifications, each handed to ``build_mlps``.
                - A plain list, as the annotation declares, is iterated directly.
                - An object exposing a ``layers`` attribute is iterated via
                  ``layers.items()``, exactly as before, for backward
                  compatibility.
            bn: not read by this constructor.
            use_xyz: forwarded to every grouper and to ``build_mlps``.
        """
        super().__init__()

        self.n_points = n_points
        self.sampler = FurthestPointSampler(n_points)

        self.groupers = nn.ModuleList()
        for r, s in zip(radius, n_samples):
            self.groupers.append(BallQueryGrouper(r, s, use_xyz))

        # The original unconditionally called ``mlps.layers.items()``, so
        # passing the annotated ``List[List[int]]`` raised AttributeError.
        # NOTE(review): ``layers.items()`` yields (key, value) pairs, so
        # build_mlps receives a tuple on that path. Confirm that is intended.
        if hasattr(mlps, "layers"):
            mlp_specs = mlps.layers.items()
        else:
            mlp_specs = mlps

        self.mlps = nn.ModuleList()
        for mlp in mlp_specs:
            self.mlps.append(self.build_mlps(mlp, use_xyz))
Beispiel #4
0
    def __init__(self, mlp: List[int], n_points=None, radius=None,
                 n_samples=None, bn=True, use_xyz=True):
        """Build a single-scale set-abstraction stage.

        Behaviour depends on ``n_points``:

        - When it is given, ``FurthestPointSampler(n_points)`` is stored as
          ``self.sampler``, and ``BallQueryGrouper(radius, n_samples,
          use_xyz)`` becomes the only grouper.
        - When it is ``None``, no ``sampler`` attribute is created and
          ``GroupAll(use_xyz)`` is the only grouper.

        Exactly one MLP is built from ``mlp`` through ``build_mlps``.
        ``bn`` is not read by this constructor.
        """
        super().__init__()

        self.n_points = n_points
        self.groupers = nn.ModuleList()

        if n_points is None:
            grouper = GroupAll(use_xyz)
        else:
            self.sampler = FurthestPointSampler(n_points)
            grouper = BallQueryGrouper(radius, n_samples, use_xyz)
        self.groupers.append(grouper)

        mlp_block = self.build_mlps(mlp, use_xyz)
        self.mlps = nn.ModuleList()
        self.mlps.append(mlp_block)