Esempio n. 1
0
def furthest_point_sample(xyz, npoint):
    # type: (torch.Tensor, int) -> torch.Tensor
    r"""
    Uses iterative furthest point sampling to select a set of npoint points that have the largest
    minimum distance

    The work is delegated to the CUDA kernel when ``xyz`` lives on a GPU and to the CPU
    implementation otherwise.

    Parameters
    ----------
    xyz : torch.Tensor
        (B, N, 3) tensor where N >= npoint
    npoint : int32
        number of points in the sampled set

    Returns
    -------
    torch.Tensor
        (B, npoint) tensor containing the set

    Raises
    ------
    ValueError
        if npoint is larger than the number of points N available in xyz
    """
    num_available = xyz.shape[1]
    if npoint > num_available:
        raise ValueError(
            "cannot sample %i points from an input set of %i points" %
            (npoint, num_available))
    if xyz.is_cuda:
        return tpcuda.furthest_point_sampling(xyz, npoint)
    # NOTE(review): the trailing ``True`` is passed positionally to the CPU backend;
    # presumably it enables a random starting point -- confirm against tpcpu.fps.
    return tpcpu.fps(xyz, npoint, True)
    def _multiscale_compute_fn(self,
                               batch,
                               collate_fn=None,
                               precompute_multi_scale=False,
                               num_scales=0,
                               sample_method='random'):
        batch = collate_fn(batch)
        if not precompute_multi_scale:
            return batch
        multiscale = []
        pos = batch.pos     # [B, N, 3]
        for i in range(num_scales):
            neighbor_idx = self._knn_search(pos, pos, self.kernel_size[i])      # [B, N, K]
            sample_num = pos.shape[1] // self.ratio[i]
            if sample_method.lower() == 'random':
                choice = torch.randperm(pos.shape[1])[:sample_num]
                sub_pos = pos[:, choice, :]             # random sampled pos   [B, S, 3]
                sub_idx = neighbor_idx[:, choice, :]    # the pool idx  [B, S, K]
            elif sample_method.lower() == 'fps':
                choice = tpcuda.furthest_point_sampling(pos.cuda(), sample_num).to(torch.long).cpu()
                sub_pos = pos.gather(dim=1, index=choice.unsqueeze(-1).repeat(1, 1, pos.shape[-1]))
                sub_idx = neighbor_idx.gather(dim=1, index=choice.unsqueeze(-1).repeat(1, 1, neighbor_idx.shape[-1]))
            else:
                raise NotImplementedError('Only `random` or `fps` sampling method is implemented!')

            up_idx = self._knn_search(sub_pos, pos, 1)      # [B, N, 1]
            multiscale.append(Data(pos=pos, neighbor_idx=neighbor_idx, sub_idx=sub_idx, up_idx=up_idx))
            pos = sub_pos

        return MultiScaleData(x=batch.x,
                              y=batch.y,
                              point_idx=batch.point_idx,
                              cloud_idx=batch.cloud_idx,
                              multiscale=multiscale)